// ati/proxy/server.rs

//! ATI proxy server — holds API keys and executes tool calls on behalf of sandbox agents.
//!
//! Authentication: ES256-signed JWT (with HS256 as a fallback). The JWT carries the caller's
//! identity, scopes, and expiry; static tokens and unsigned scope lists are no longer supported.
//!
//! Usage: `ati proxy --port 8080 [--ati-dir ~/.ati]`
7use axum::{
8    body::Body,
9    extract::State,
10    http::{Request as HttpRequest, StatusCode},
11    middleware::{self, Next},
12    response::{IntoResponse, Response},
13    routing::{get, post},
14    Json, Router,
15};
16use serde::{Deserialize, Serialize};
17use serde_json::Value;
18use std::collections::HashMap;
19use std::net::SocketAddr;
20use std::path::PathBuf;
21use std::sync::Arc;
22
23use crate::core::auth_generator::{AuthCache, GenContext};
24use crate::core::http;
25use crate::core::jwt::{self, JwtConfig, TokenClaims};
26use crate::core::keyring::Keyring;
27use crate::core::manifest::ManifestRegistry;
28use crate::core::mcp_client;
29use crate::core::response;
30use crate::core::scope::ScopeConfig;
31use crate::core::skill::{self, SkillRegistry};
32use crate::core::xai;
33
34/// Shared state for the proxy server.
35pub struct ProxyState {
36    pub registry: ManifestRegistry,
37    pub skill_registry: SkillRegistry,
38    pub keyring: Keyring,
39    pub verbose: bool,
40    /// JWT validation config (None = auth disabled / dev mode).
41    pub jwt_config: Option<JwtConfig>,
42    /// Pre-computed JWKS JSON for the /.well-known/jwks.json endpoint.
43    pub jwks_json: Option<Value>,
44    /// Shared cache for dynamically generated auth credentials.
45    pub auth_cache: AuthCache,
46}
47
48// --- Request/Response types ---
49
50#[derive(Debug, Deserialize)]
51pub struct CallRequest {
52    pub tool_name: String,
53    /// Tool arguments — accepts a JSON object (key-value pairs) for HTTP/MCP/OpenAPI tools,
54    /// or a JSON array of strings / a single string for CLI tools.
55    /// The proxy auto-detects the handler type and routes accordingly.
56    #[serde(default = "default_args")]
57    pub args: Value,
58    /// Deprecated: use `args` with an array value instead.
59    /// Kept for backward compatibility — if present, takes precedence for CLI tools.
60    #[serde(default)]
61    pub raw_args: Option<Vec<String>>,
62}
63
64fn default_args() -> Value {
65    Value::Object(serde_json::Map::new())
66}
67
68impl CallRequest {
69    /// Extract args as a HashMap for HTTP/MCP/OpenAPI tools.
70    /// If `args` is a JSON object, returns its entries.
71    /// If `args` is something else (array, string), returns an empty map.
72    fn args_as_map(&self) -> HashMap<String, Value> {
73        match &self.args {
74            Value::Object(map) => map.iter().map(|(k, v)| (k.clone(), v.clone())).collect(),
75            _ => HashMap::new(),
76        }
77    }
78
79    /// Extract positional args for CLI tools.
80    /// Priority: explicit `raw_args` field > `args` array > `args` string > `args._positional` > empty.
81    fn args_as_positional(&self) -> Vec<String> {
82        // Backward compat: explicit raw_args wins
83        if let Some(ref raw) = self.raw_args {
84            return raw.clone();
85        }
86        match &self.args {
87            // ["pr", "list", "--repo", "X"]
88            Value::Array(arr) => arr
89                .iter()
90                .filter_map(|v| match v {
91                    Value::String(s) => Some(s.clone()),
92                    other => Some(other.to_string()),
93                })
94                .collect(),
95            // "pr list --repo X"
96            Value::String(s) => s.split_whitespace().map(String::from).collect(),
97            // {"_positional": ["pr", "list"]} or {"--key": "value"} converted to CLI flags
98            Value::Object(map) => {
99                if let Some(Value::Array(pos)) = map.get("_positional") {
100                    return pos
101                        .iter()
102                        .filter_map(|v| match v {
103                            Value::String(s) => Some(s.clone()),
104                            other => Some(other.to_string()),
105                        })
106                        .collect();
107                }
108                // Convert map entries to --key value pairs
109                let mut result = Vec::new();
110                for (k, v) in map {
111                    result.push(format!("--{k}"));
112                    match v {
113                        Value::String(s) => result.push(s.clone()),
114                        Value::Bool(true) => {} // flag, no value needed
115                        other => result.push(other.to_string()),
116                    }
117                }
118                result
119            }
120            _ => Vec::new(),
121        }
122    }
123}
124
125#[derive(Debug, Serialize)]
126pub struct CallResponse {
127    pub result: Value,
128    #[serde(skip_serializing_if = "Option::is_none")]
129    pub error: Option<String>,
130}
131
132#[derive(Debug, Deserialize)]
133pub struct HelpRequest {
134    pub query: String,
135    #[serde(default)]
136    pub tool: Option<String>,
137}
138
139#[derive(Debug, Serialize)]
140pub struct HelpResponse {
141    pub content: String,
142    #[serde(skip_serializing_if = "Option::is_none")]
143    pub error: Option<String>,
144}
145
146#[derive(Debug, Serialize)]
147pub struct HealthResponse {
148    pub status: String,
149    pub version: String,
150    pub tools: usize,
151    pub providers: usize,
152    pub skills: usize,
153    pub auth: String,
154}
155
156// --- Skill endpoint types ---
157
158#[derive(Debug, Deserialize)]
159pub struct SkillsQuery {
160    #[serde(default)]
161    pub category: Option<String>,
162    #[serde(default)]
163    pub provider: Option<String>,
164    #[serde(default)]
165    pub tool: Option<String>,
166    #[serde(default)]
167    pub search: Option<String>,
168}
169
170#[derive(Debug, Deserialize)]
171pub struct SkillDetailQuery {
172    #[serde(default)]
173    pub meta: Option<bool>,
174    #[serde(default)]
175    pub refs: Option<bool>,
176}
177
178#[derive(Debug, Deserialize)]
179pub struct SkillResolveRequest {
180    pub scopes: Vec<String>,
181}
182
183// --- Handlers ---
184
185async fn handle_call(
186    State(state): State<Arc<ProxyState>>,
187    req: HttpRequest<Body>,
188) -> impl IntoResponse {
189    // Extract JWT claims from request extensions (set by auth middleware)
190    let claims = req.extensions().get::<TokenClaims>().cloned();
191
192    // Parse request body
193    let body_bytes = match axum::body::to_bytes(req.into_body(), 10 * 1024 * 1024).await {
194        Ok(b) => b,
195        Err(e) => {
196            return (
197                StatusCode::BAD_REQUEST,
198                Json(CallResponse {
199                    result: Value::Null,
200                    error: Some(format!("Failed to read request body: {e}")),
201                }),
202            );
203        }
204    };
205
206    let call_req: CallRequest = match serde_json::from_slice(&body_bytes) {
207        Ok(r) => r,
208        Err(e) => {
209            return (
210                StatusCode::UNPROCESSABLE_ENTITY,
211                Json(CallResponse {
212                    result: Value::Null,
213                    error: Some(format!("Invalid request: {e}")),
214                }),
215            );
216        }
217    };
218
219    if state.verbose {
220        eprintln!(
221            "[proxy] /call tool={} args={:?}",
222            call_req.tool_name, call_req.args
223        );
224    }
225
226    // Look up tool in registry
227    let (provider, tool) = match state.registry.get_tool(&call_req.tool_name) {
228        Some(pt) => pt,
229        None => {
230            return (
231                StatusCode::NOT_FOUND,
232                Json(CallResponse {
233                    result: Value::Null,
234                    error: Some(format!("Unknown tool: '{}'", call_req.tool_name)),
235                }),
236            );
237        }
238    };
239
240    // Scope enforcement from JWT claims
241    if let Some(tool_scope) = &tool.scope {
242        let scopes = match &claims {
243            Some(c) => ScopeConfig::from_jwt(c),
244            None if state.jwt_config.is_none() => ScopeConfig::unrestricted(), // Dev mode
245            None => {
246                return (
247                    StatusCode::FORBIDDEN,
248                    Json(CallResponse {
249                        result: Value::Null,
250                        error: Some("Authentication required — no JWT provided".into()),
251                    }),
252                );
253            }
254        };
255
256        if let Err(e) = scopes.check_access(&call_req.tool_name, tool_scope) {
257            return (
258                StatusCode::FORBIDDEN,
259                Json(CallResponse {
260                    result: Value::Null,
261                    error: Some(format!("Access denied: {e}")),
262                }),
263            );
264        }
265    }
266
267    // Rate limit check
268    {
269        let scopes = match &claims {
270            Some(c) => ScopeConfig::from_jwt(c),
271            None => ScopeConfig::unrestricted(),
272        };
273        if let Some(ref rate_config) = scopes.rate_config {
274            if let Err(e) = crate::core::rate::check_and_record(&call_req.tool_name, rate_config) {
275                return (
276                    StatusCode::TOO_MANY_REQUESTS,
277                    Json(CallResponse {
278                        result: Value::Null,
279                        error: Some(format!("{e}")),
280                    }),
281                );
282            }
283        }
284    }
285
286    // Build auth generator context from JWT claims
287    let gen_ctx = GenContext {
288        jwt_sub: claims
289            .as_ref()
290            .map(|c| c.sub.clone())
291            .unwrap_or_else(|| "dev".into()),
292        jwt_scope: claims
293            .as_ref()
294            .map(|c| c.scope.clone())
295            .unwrap_or_else(|| "*".into()),
296        tool_name: call_req.tool_name.clone(),
297        timestamp: crate::core::jwt::now_secs(),
298    };
299
300    // Execute tool call — dispatch based on handler type, with timing for audit
301    let agent_sub = claims.as_ref().map(|c| c.sub.clone()).unwrap_or_default();
302    let start = std::time::Instant::now();
303
304    let response = match provider.handler.as_str() {
305        "mcp" => {
306            let args_map = call_req.args_as_map();
307            match mcp_client::execute_with_gen(
308                provider,
309                &call_req.tool_name,
310                &args_map,
311                &state.keyring,
312                Some(&gen_ctx),
313                Some(&state.auth_cache),
314            )
315            .await
316            {
317                Ok(result) => (
318                    StatusCode::OK,
319                    Json(CallResponse {
320                        result,
321                        error: None,
322                    }),
323                ),
324                Err(e) => (
325                    StatusCode::BAD_GATEWAY,
326                    Json(CallResponse {
327                        result: Value::Null,
328                        error: Some(format!("MCP error: {e}")),
329                    }),
330                ),
331            }
332        }
333        "cli" => {
334            let positional = call_req.args_as_positional();
335            match crate::core::cli_executor::execute_with_gen(
336                provider,
337                &positional,
338                &state.keyring,
339                Some(&gen_ctx),
340                Some(&state.auth_cache),
341            )
342            .await
343            {
344                Ok(result) => (
345                    StatusCode::OK,
346                    Json(CallResponse {
347                        result,
348                        error: None,
349                    }),
350                ),
351                Err(e) => (
352                    StatusCode::BAD_GATEWAY,
353                    Json(CallResponse {
354                        result: Value::Null,
355                        error: Some(format!("CLI error: {e}")),
356                    }),
357                ),
358            }
359        }
360        _ => {
361            let args_map = call_req.args_as_map();
362            let raw_response = match match provider.handler.as_str() {
363                "xai" => xai::execute_xai_tool(provider, tool, &args_map, &state.keyring).await,
364                _ => {
365                    http::execute_tool_with_gen(
366                        provider,
367                        tool,
368                        &args_map,
369                        &state.keyring,
370                        Some(&gen_ctx),
371                        Some(&state.auth_cache),
372                    )
373                    .await
374                }
375            } {
376                Ok(resp) => resp,
377                Err(e) => {
378                    let duration = start.elapsed();
379                    write_proxy_audit(&call_req, &agent_sub, duration, Some(&e.to_string()));
380                    return (
381                        StatusCode::BAD_GATEWAY,
382                        Json(CallResponse {
383                            result: Value::Null,
384                            error: Some(format!("Upstream API error: {e}")),
385                        }),
386                    );
387                }
388            };
389
390            let processed = match response::process_response(&raw_response, tool.response.as_ref())
391            {
392                Ok(p) => p,
393                Err(e) => {
394                    let duration = start.elapsed();
395                    write_proxy_audit(&call_req, &agent_sub, duration, Some(&e.to_string()));
396                    return (
397                        StatusCode::INTERNAL_SERVER_ERROR,
398                        Json(CallResponse {
399                            result: raw_response,
400                            error: Some(format!("Response processing error: {e}")),
401                        }),
402                    );
403                }
404            };
405
406            (
407                StatusCode::OK,
408                Json(CallResponse {
409                    result: processed,
410                    error: None,
411                }),
412            )
413        }
414    };
415
416    let duration = start.elapsed();
417    let error_msg = response.1.error.as_deref();
418    write_proxy_audit(&call_req, &agent_sub, duration, error_msg);
419
420    response
421}
422
423async fn handle_help(
424    State(state): State<Arc<ProxyState>>,
425    Json(req): Json<HelpRequest>,
426) -> impl IntoResponse {
427    if state.verbose {
428        eprintln!("[proxy] /help query={} tool={:?}", req.query, req.tool);
429    }
430
431    let (llm_provider, llm_tool) = match state.registry.get_tool("_chat_completion") {
432        Some(pt) => pt,
433        None => {
434            return (
435                StatusCode::SERVICE_UNAVAILABLE,
436                Json(HelpResponse {
437                    content: String::new(),
438                    error: Some("No _llm.toml manifest found. Proxy help requires a configured LLM provider.".into()),
439                }),
440            );
441        }
442    };
443
444    let api_key = match llm_provider
445        .auth_key_name
446        .as_deref()
447        .and_then(|k| state.keyring.get(k))
448    {
449        Some(key) => key.to_string(),
450        None => {
451            return (
452                StatusCode::SERVICE_UNAVAILABLE,
453                Json(HelpResponse {
454                    content: String::new(),
455                    error: Some("LLM API key not found in keyring".into()),
456                }),
457            );
458        }
459    };
460
461    let scopes = ScopeConfig::unrestricted();
462    let resolved_skills = skill::resolve_skills(&state.skill_registry, &state.registry, &scopes);
463    let skills_section = if resolved_skills.is_empty() {
464        String::new()
465    } else {
466        format!(
467            "## Available Skills (methodology guides)\n{}",
468            skill::build_skill_context(&resolved_skills)
469        )
470    };
471
472    // Build system prompt — scoped or unscoped
473    let system_prompt = if let Some(ref tool_name) = req.tool {
474        // Scoped mode: narrow tools to the specified tool or provider
475        match build_scoped_prompt(tool_name, &state.registry, &skills_section) {
476            Some(prompt) => prompt,
477            None => {
478                // Fall back to unscoped if tool/provider not found
479                if state.verbose {
480                    eprintln!(
481                        "[proxy] /help scope '{}' not found, falling back to unscoped",
482                        tool_name
483                    );
484                }
485                let all_tools = state.registry.list_public_tools();
486                let tools_context = build_tool_context(&all_tools);
487                HELP_SYSTEM_PROMPT
488                    .replace("{tools}", &tools_context)
489                    .replace("{skills_section}", &skills_section)
490            }
491        }
492    } else {
493        let all_tools = state.registry.list_public_tools();
494        let tools_context = build_tool_context(&all_tools);
495        HELP_SYSTEM_PROMPT
496            .replace("{tools}", &tools_context)
497            .replace("{skills_section}", &skills_section)
498    };
499
500    let request_body = serde_json::json!({
501        "model": "zai-glm-4.7",
502        "messages": [
503            {"role": "system", "content": system_prompt},
504            {"role": "user", "content": req.query}
505        ],
506        "max_completion_tokens": 1536,
507        "temperature": 0.3
508    });
509
510    let client = reqwest::Client::new();
511    let url = format!(
512        "{}{}",
513        llm_provider.base_url.trim_end_matches('/'),
514        llm_tool.endpoint
515    );
516
517    let response = match client
518        .post(&url)
519        .bearer_auth(&api_key)
520        .json(&request_body)
521        .send()
522        .await
523    {
524        Ok(r) => r,
525        Err(e) => {
526            return (
527                StatusCode::BAD_GATEWAY,
528                Json(HelpResponse {
529                    content: String::new(),
530                    error: Some(format!("LLM request failed: {e}")),
531                }),
532            );
533        }
534    };
535
536    if !response.status().is_success() {
537        let status = response.status();
538        let body = response.text().await.unwrap_or_default();
539        return (
540            StatusCode::BAD_GATEWAY,
541            Json(HelpResponse {
542                content: String::new(),
543                error: Some(format!("LLM API error ({status}): {body}")),
544            }),
545        );
546    }
547
548    let body: Value = match response.json().await {
549        Ok(b) => b,
550        Err(e) => {
551            return (
552                StatusCode::INTERNAL_SERVER_ERROR,
553                Json(HelpResponse {
554                    content: String::new(),
555                    error: Some(format!("Failed to parse LLM response: {e}")),
556                }),
557            );
558        }
559    };
560
561    let content = body
562        .pointer("/choices/0/message/content")
563        .and_then(|c| c.as_str())
564        .unwrap_or("No response from LLM")
565        .to_string();
566
567    (
568        StatusCode::OK,
569        Json(HelpResponse {
570            content,
571            error: None,
572        }),
573    )
574}
575
576async fn handle_health(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
577    let auth = if state.jwt_config.is_some() {
578        "jwt"
579    } else {
580        "disabled"
581    };
582
583    Json(HealthResponse {
584        status: "ok".into(),
585        version: env!("CARGO_PKG_VERSION").into(),
586        tools: state.registry.list_public_tools().len(),
587        providers: state.registry.list_providers().len(),
588        skills: state.skill_registry.skill_count(),
589        auth: auth.into(),
590    })
591}
592
593/// GET /.well-known/jwks.json — serves the public key for JWT validation.
594async fn handle_jwks(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
595    match &state.jwks_json {
596        Some(jwks) => (StatusCode::OK, Json(jwks.clone())),
597        None => (
598            StatusCode::NOT_FOUND,
599            Json(serde_json::json!({"error": "JWKS not configured"})),
600        ),
601    }
602}
603
604// ---------------------------------------------------------------------------
605// POST /mcp — MCP JSON-RPC proxy endpoint
606// ---------------------------------------------------------------------------
607
608async fn handle_mcp(
609    State(state): State<Arc<ProxyState>>,
610    Json(msg): Json<Value>,
611) -> impl IntoResponse {
612    let method = msg.get("method").and_then(|m| m.as_str()).unwrap_or("");
613    let id = msg.get("id").cloned();
614
615    if state.verbose {
616        eprintln!("[proxy] /mcp method={method}");
617    }
618
619    match method {
620        "initialize" => {
621            let result = serde_json::json!({
622                "protocolVersion": "2025-03-26",
623                "capabilities": {
624                    "tools": { "listChanged": false }
625                },
626                "serverInfo": {
627                    "name": "ati-proxy",
628                    "version": env!("CARGO_PKG_VERSION")
629                }
630            });
631            jsonrpc_success(id, result)
632        }
633
634        "notifications/initialized" => (StatusCode::ACCEPTED, Json(Value::Null)),
635
636        "tools/list" => {
637            let all_tools = state.registry.list_public_tools();
638            let mcp_tools: Vec<Value> = all_tools
639                .iter()
640                .map(|(_provider, tool)| {
641                    serde_json::json!({
642                        "name": tool.name,
643                        "description": tool.description,
644                        "inputSchema": tool.input_schema.clone().unwrap_or(serde_json::json!({
645                            "type": "object",
646                            "properties": {}
647                        }))
648                    })
649                })
650                .collect();
651
652            let result = serde_json::json!({
653                "tools": mcp_tools,
654            });
655            jsonrpc_success(id, result)
656        }
657
658        "tools/call" => {
659            let params = msg.get("params").cloned().unwrap_or(Value::Null);
660            let tool_name = params.get("name").and_then(|n| n.as_str()).unwrap_or("");
661            let arguments: HashMap<String, Value> = params
662                .get("arguments")
663                .and_then(|a| serde_json::from_value(a.clone()).ok())
664                .unwrap_or_default();
665
666            if tool_name.is_empty() {
667                return jsonrpc_error(id, -32602, "Missing tool name in params.name");
668            }
669
670            let (provider, _tool) = match state.registry.get_tool(tool_name) {
671                Some(pt) => pt,
672                None => {
673                    return jsonrpc_error(id, -32602, &format!("Unknown tool: '{tool_name}'"));
674                }
675            };
676
677            if state.verbose {
678                eprintln!(
679                    "[proxy] /mcp tools/call name={tool_name} provider={}",
680                    provider.name
681                );
682            }
683
684            let mcp_gen_ctx = GenContext {
685                jwt_sub: "dev".into(),
686                jwt_scope: "*".into(),
687                tool_name: tool_name.to_string(),
688                timestamp: crate::core::jwt::now_secs(),
689            };
690
691            let result = if provider.is_mcp() {
692                mcp_client::execute_with_gen(
693                    provider,
694                    tool_name,
695                    &arguments,
696                    &state.keyring,
697                    Some(&mcp_gen_ctx),
698                    Some(&state.auth_cache),
699                )
700                .await
701            } else if provider.is_cli() {
702                // Convert arguments map to CLI-style args for MCP passthrough
703                let raw: Vec<String> = arguments
704                    .iter()
705                    .flat_map(|(k, v)| {
706                        let val = match v {
707                            Value::String(s) => s.clone(),
708                            other => other.to_string(),
709                        };
710                        vec![format!("--{k}"), val]
711                    })
712                    .collect();
713                crate::core::cli_executor::execute_with_gen(
714                    provider,
715                    &raw,
716                    &state.keyring,
717                    Some(&mcp_gen_ctx),
718                    Some(&state.auth_cache),
719                )
720                .await
721                .map_err(|e| mcp_client::McpError::Transport(e.to_string()))
722            } else {
723                match match provider.handler.as_str() {
724                    "xai" => {
725                        xai::execute_xai_tool(provider, _tool, &arguments, &state.keyring).await
726                    }
727                    _ => {
728                        http::execute_tool_with_gen(
729                            provider,
730                            _tool,
731                            &arguments,
732                            &state.keyring,
733                            Some(&mcp_gen_ctx),
734                            Some(&state.auth_cache),
735                        )
736                        .await
737                    }
738                } {
739                    Ok(val) => Ok(val),
740                    Err(e) => Err(mcp_client::McpError::Transport(e.to_string())),
741                }
742            };
743
744            match result {
745                Ok(value) => {
746                    let text = match &value {
747                        Value::String(s) => s.clone(),
748                        other => serde_json::to_string_pretty(other).unwrap_or_default(),
749                    };
750                    let mcp_result = serde_json::json!({
751                        "content": [{"type": "text", "text": text}],
752                        "isError": false,
753                    });
754                    jsonrpc_success(id, mcp_result)
755                }
756                Err(e) => {
757                    let mcp_result = serde_json::json!({
758                        "content": [{"type": "text", "text": format!("Error: {e}")}],
759                        "isError": true,
760                    });
761                    jsonrpc_success(id, mcp_result)
762                }
763            }
764        }
765
766        _ => jsonrpc_error(id, -32601, &format!("Method not found: '{method}'")),
767    }
768}
769
770fn jsonrpc_success(id: Option<Value>, result: Value) -> (StatusCode, Json<Value>) {
771    (
772        StatusCode::OK,
773        Json(serde_json::json!({
774            "jsonrpc": "2.0",
775            "id": id,
776            "result": result,
777        })),
778    )
779}
780
781fn jsonrpc_error(id: Option<Value>, code: i64, message: &str) -> (StatusCode, Json<Value>) {
782    (
783        StatusCode::OK,
784        Json(serde_json::json!({
785            "jsonrpc": "2.0",
786            "id": id,
787            "error": {
788                "code": code,
789                "message": message,
790            }
791        })),
792    )
793}
794
795// ---------------------------------------------------------------------------
796// Skill endpoints
797// ---------------------------------------------------------------------------
798
799async fn handle_skills_list(
800    State(state): State<Arc<ProxyState>>,
801    axum::extract::Query(query): axum::extract::Query<SkillsQuery>,
802) -> impl IntoResponse {
803    if state.verbose {
804        eprintln!(
805            "[proxy] GET /skills category={:?} provider={:?} tool={:?} search={:?}",
806            query.category, query.provider, query.tool, query.search
807        );
808    }
809
810    let skills: Vec<&skill::SkillMeta> = if let Some(search_query) = &query.search {
811        state.skill_registry.search(search_query)
812    } else if let Some(cat) = &query.category {
813        state.skill_registry.skills_for_category(cat)
814    } else if let Some(prov) = &query.provider {
815        state.skill_registry.skills_for_provider(prov)
816    } else if let Some(t) = &query.tool {
817        state.skill_registry.skills_for_tool(t)
818    } else {
819        state.skill_registry.list_skills().iter().collect()
820    };
821
822    let json: Vec<Value> = skills
823        .iter()
824        .map(|s| {
825            serde_json::json!({
826                "name": s.name,
827                "version": s.version,
828                "description": s.description,
829                "tools": s.tools,
830                "providers": s.providers,
831                "categories": s.categories,
832                "hint": s.hint,
833            })
834        })
835        .collect();
836
837    (StatusCode::OK, Json(Value::Array(json)))
838}
839
840async fn handle_skill_detail(
841    State(state): State<Arc<ProxyState>>,
842    axum::extract::Path(name): axum::extract::Path<String>,
843    axum::extract::Query(query): axum::extract::Query<SkillDetailQuery>,
844) -> impl IntoResponse {
845    if state.verbose {
846        eprintln!(
847            "[proxy] GET /skills/{name} meta={:?} refs={:?}",
848            query.meta, query.refs
849        );
850    }
851
852    let skill_meta = match state.skill_registry.get_skill(&name) {
853        Some(s) => s,
854        None => {
855            return (
856                StatusCode::NOT_FOUND,
857                Json(serde_json::json!({"error": format!("Skill '{name}' not found")})),
858            );
859        }
860    };
861
862    if query.meta.unwrap_or(false) {
863        return (
864            StatusCode::OK,
865            Json(serde_json::json!({
866                "name": skill_meta.name,
867                "version": skill_meta.version,
868                "description": skill_meta.description,
869                "author": skill_meta.author,
870                "tools": skill_meta.tools,
871                "providers": skill_meta.providers,
872                "categories": skill_meta.categories,
873                "keywords": skill_meta.keywords,
874                "hint": skill_meta.hint,
875                "depends_on": skill_meta.depends_on,
876                "suggests": skill_meta.suggests,
877                "license": skill_meta.license,
878                "compatibility": skill_meta.compatibility,
879                "allowed_tools": skill_meta.allowed_tools,
880                "format": skill_meta.format,
881            })),
882        );
883    }
884
885    let content = match state.skill_registry.read_content(&name) {
886        Ok(c) => c,
887        Err(e) => {
888            return (
889                StatusCode::INTERNAL_SERVER_ERROR,
890                Json(serde_json::json!({"error": format!("Failed to read skill: {e}")})),
891            );
892        }
893    };
894
895    let mut response = serde_json::json!({
896        "name": skill_meta.name,
897        "version": skill_meta.version,
898        "description": skill_meta.description,
899        "content": content,
900    });
901
902    if query.refs.unwrap_or(false) {
903        if let Ok(refs) = state.skill_registry.list_references(&name) {
904            response["references"] = serde_json::json!(refs);
905        }
906    }
907
908    (StatusCode::OK, Json(response))
909}
910
911async fn handle_skills_resolve(
912    State(state): State<Arc<ProxyState>>,
913    Json(req): Json<SkillResolveRequest>,
914) -> impl IntoResponse {
915    if state.verbose {
916        eprintln!("[proxy] POST /skills/resolve scopes={:?}", req.scopes);
917    }
918
919    let scopes = ScopeConfig {
920        scopes: req.scopes,
921        sub: String::new(),
922        expires_at: 0,
923        rate_config: None,
924    };
925
926    let resolved = skill::resolve_skills(&state.skill_registry, &state.registry, &scopes);
927
928    let json: Vec<Value> = resolved
929        .iter()
930        .map(|s| {
931            serde_json::json!({
932                "name": s.name,
933                "version": s.version,
934                "description": s.description,
935                "tools": s.tools,
936                "providers": s.providers,
937                "categories": s.categories,
938            })
939        })
940        .collect();
941
942    (StatusCode::OK, Json(Value::Array(json)))
943}
944
945// --- Auth middleware ---
946
947/// JWT authentication middleware.
948///
949/// - /health and /.well-known/jwks.json → skip auth
950/// - JWT configured → validate Bearer token, attach claims to request extensions
951/// - No JWT configured → allow all (dev mode)
952async fn auth_middleware(
953    State(state): State<Arc<ProxyState>>,
954    mut req: HttpRequest<Body>,
955    next: Next,
956) -> Result<Response, StatusCode> {
957    let path = req.uri().path();
958
959    // Skip auth for public endpoints
960    if path == "/health" || path == "/.well-known/jwks.json" {
961        return Ok(next.run(req).await);
962    }
963
964    // If no JWT configured, allow all (dev mode)
965    let jwt_config = match &state.jwt_config {
966        Some(c) => c,
967        None => return Ok(next.run(req).await),
968    };
969
970    // Extract Authorization: Bearer <token>
971    let auth_header = req
972        .headers()
973        .get("authorization")
974        .and_then(|v| v.to_str().ok());
975
976    let token = match auth_header {
977        Some(header) if header.starts_with("Bearer ") => &header[7..],
978        _ => return Err(StatusCode::UNAUTHORIZED),
979    };
980
981    // Validate JWT
982    match jwt::validate(token, jwt_config) {
983        Ok(claims) => {
984            if state.verbose {
985                eprintln!(
986                    "[proxy] JWT validated: sub={} scopes={}",
987                    claims.sub, claims.scope
988                );
989            }
990            req.extensions_mut().insert(claims);
991            Ok(next.run(req).await)
992        }
993        Err(e) => {
994            if state.verbose {
995                eprintln!("[proxy] JWT validation failed: {e}");
996            }
997            Err(StatusCode::UNAUTHORIZED)
998        }
999    }
1000}
1001
1002// --- Router builder ---
1003
1004/// Build the axum Router from a pre-constructed ProxyState.
1005pub fn build_router(state: Arc<ProxyState>) -> Router {
1006    Router::new()
1007        .route("/call", post(handle_call))
1008        .route("/help", post(handle_help))
1009        .route("/mcp", post(handle_mcp))
1010        .route("/skills", get(handle_skills_list))
1011        .route("/skills/resolve", post(handle_skills_resolve))
1012        .route("/skills/{name}", get(handle_skill_detail))
1013        .route("/health", get(handle_health))
1014        .route("/.well-known/jwks.json", get(handle_jwks))
1015        .layer(middleware::from_fn_with_state(
1016            state.clone(),
1017            auth_middleware,
1018        ))
1019        .with_state(state)
1020}
1021
1022// --- Server startup ---
1023
1024/// Start the proxy server.
1025pub async fn run(
1026    port: u16,
1027    bind_addr: Option<String>,
1028    ati_dir: PathBuf,
1029    verbose: bool,
1030    env_keys: bool,
1031) -> Result<(), Box<dyn std::error::Error>> {
1032    // Load manifests
1033    let manifests_dir = ati_dir.join("manifests");
1034    let registry = ManifestRegistry::load(&manifests_dir)?;
1035
1036    let tool_count = registry.list_public_tools().len();
1037    let provider_count = registry.list_providers().len();
1038
1039    // Load keyring
1040    let keyring_source;
1041    let keyring = if env_keys {
1042        // --env-keys: scan ATI_KEY_* environment variables
1043        let kr = Keyring::from_env();
1044        let key_names = kr.key_names();
1045        eprintln!(
1046            "  Loaded {} API keys from ATI_KEY_* env vars:",
1047            key_names.len()
1048        );
1049        for name in &key_names {
1050            eprintln!("    - {name}");
1051        }
1052        keyring_source = "env-vars (ATI_KEY_*)";
1053        kr
1054    } else {
1055        // Cascade: keyring.enc (sealed) → keyring.enc (persistent) → credentials → empty
1056        let keyring_path = ati_dir.join("keyring.enc");
1057        if keyring_path.exists() {
1058            if let Ok(kr) = Keyring::load(&keyring_path) {
1059                keyring_source = "keyring.enc (sealed key)";
1060                kr
1061            } else if let Ok(kr) = Keyring::load_local(&keyring_path, &ati_dir) {
1062                keyring_source = "keyring.enc (persistent key)";
1063                kr
1064            } else {
1065                eprintln!("Warning: keyring.enc exists but could not be decrypted");
1066                keyring_source = "empty (decryption failed)";
1067                Keyring::empty()
1068            }
1069        } else {
1070            let creds_path = ati_dir.join("credentials");
1071            if creds_path.exists() {
1072                match Keyring::load_credentials(&creds_path) {
1073                    Ok(kr) => {
1074                        keyring_source = "credentials (plaintext)";
1075                        kr
1076                    }
1077                    Err(e) => {
1078                        eprintln!("Warning: Failed to load credentials: {e}");
1079                        keyring_source = "empty (credentials error)";
1080                        Keyring::empty()
1081                    }
1082                }
1083            } else {
1084                eprintln!(
1085                    "Warning: No keyring.enc or credentials found — running without API keys"
1086                );
1087                eprintln!("Tools requiring authentication will fail.");
1088                keyring_source = "empty (no auth)";
1089                Keyring::empty()
1090            }
1091        }
1092    };
1093
1094    // Log MCP and OpenAPI providers
1095    let mcp_providers: Vec<(String, String)> = registry
1096        .list_mcp_providers()
1097        .iter()
1098        .map(|p| (p.name.clone(), p.mcp_transport_type().to_string()))
1099        .collect();
1100    let mcp_count = mcp_providers.len();
1101    let openapi_providers: Vec<String> = registry
1102        .list_openapi_providers()
1103        .iter()
1104        .map(|p| p.name.clone())
1105        .collect();
1106    let openapi_count = openapi_providers.len();
1107
1108    // Load skill registry
1109    let skills_dir = ati_dir.join("skills");
1110    let skill_registry = SkillRegistry::load(&skills_dir).unwrap_or_else(|e| {
1111        if verbose {
1112            eprintln!("Warning: failed to load skills: {e}");
1113        }
1114        SkillRegistry::load(std::path::Path::new("/nonexistent-fallback")).unwrap()
1115    });
1116    let skill_count = skill_registry.skill_count();
1117
1118    // Load JWT config from environment
1119    let jwt_config = match jwt::config_from_env() {
1120        Ok(config) => config,
1121        Err(e) => {
1122            eprintln!("Warning: JWT config error: {e}");
1123            None
1124        }
1125    };
1126
1127    let auth_status = if jwt_config.is_some() {
1128        "JWT enabled"
1129    } else {
1130        "DISABLED (no JWT keys configured)"
1131    };
1132
1133    // Build JWKS for the endpoint
1134    let jwks_json = jwt_config.as_ref().and_then(|config| {
1135        config
1136            .public_key_pem
1137            .as_ref()
1138            .and_then(|pem| jwt::public_key_to_jwks(pem, config.algorithm, "ati-proxy-1").ok())
1139    });
1140
1141    let state = Arc::new(ProxyState {
1142        registry,
1143        skill_registry,
1144        keyring,
1145        verbose,
1146        jwt_config,
1147        jwks_json,
1148        auth_cache: AuthCache::new(),
1149    });
1150
1151    let app = build_router(state);
1152
1153    let addr: SocketAddr = if let Some(ref bind) = bind_addr {
1154        format!("{bind}:{port}").parse()?
1155    } else {
1156        SocketAddr::from(([127, 0, 0, 1], port))
1157    };
1158
1159    eprintln!("ATI proxy server v{}", env!("CARGO_PKG_VERSION"));
1160    eprintln!("  Listening on http://{addr}");
1161    eprintln!("  Auth: {auth_status}");
1162    eprintln!("  ATI dir: {}", ati_dir.display());
1163    eprintln!("  Tools: {tool_count}, Providers: {provider_count} ({mcp_count} MCP, {openapi_count} OpenAPI)");
1164    eprintln!("  Skills: {skill_count}");
1165    eprintln!("  Keyring: {keyring_source}");
1166    eprintln!("  Endpoints: /call, /help, /mcp, /skills, /skills/:name, /skills/resolve, /health, /.well-known/jwks.json");
1167    for (name, transport) in &mcp_providers {
1168        eprintln!("  MCP: {name} ({transport})");
1169    }
1170    for name in &openapi_providers {
1171        eprintln!("  OpenAPI: {name}");
1172    }
1173
1174    let listener = tokio::net::TcpListener::bind(addr).await?;
1175    axum::serve(listener, app).await?;
1176
1177    Ok(())
1178}
1179
1180/// Write an audit entry from the proxy server. Failures are silently ignored.
1181fn write_proxy_audit(
1182    call_req: &CallRequest,
1183    agent_sub: &str,
1184    duration: std::time::Duration,
1185    error: Option<&str>,
1186) {
1187    let entry = crate::core::audit::AuditEntry {
1188        ts: chrono::Utc::now().to_rfc3339(),
1189        tool: call_req.tool_name.clone(),
1190        args: crate::core::audit::sanitize_args(&call_req.args),
1191        status: if error.is_some() {
1192            crate::core::audit::AuditStatus::Error
1193        } else {
1194            crate::core::audit::AuditStatus::Ok
1195        },
1196        duration_ms: duration.as_millis() as u64,
1197        agent_sub: agent_sub.to_string(),
1198        error: error.map(|s| s.to_string()),
1199        exit_code: None,
1200    };
1201    let _ = crate::core::audit::append(&entry);
1202}
1203
1204// --- Helpers ---
1205
/// General-purpose system prompt template for tool-usage help across the whole catalogue.
/// NOTE(review): presumably used by the unscoped `/help` case; confirm against the caller.
///
/// Contains two placeholders, `{tools}` and `{skills_section}`. This is a plain `&str`
/// constant, not a `format!` literal, so the caller must substitute the placeholders itself
/// (for example with `str::replace`).
const HELP_SYSTEM_PROMPT: &str = r#"You are a helpful assistant for an AI agent that uses external tools via the `ati` CLI.

## Available Tools
{tools}

{skills_section}

Answer the agent's question naturally, like a knowledgeable colleague would. Keep it short but useful:

- Explain which tools to use and why, with `ati run` commands showing realistic parameter values
- If multiple steps are needed, walk through them briefly in order
- Mention important gotchas or parameter choices that matter
- If skills are relevant, suggest `ati skill show <name>` for the full methodology

Keep your answer concise — a few short paragraphs with embedded code blocks. Only recommend tools from the list above."#;
1221
1222fn build_tool_context(
1223    tools: &[(
1224        &crate::core::manifest::Provider,
1225        &crate::core::manifest::Tool,
1226    )],
1227) -> String {
1228    let mut summaries = Vec::new();
1229    for (provider, tool) in tools {
1230        let mut summary = if let Some(cat) = &provider.category {
1231            format!(
1232                "- **{}** (provider: {}, category: {}): {}",
1233                tool.name, provider.name, cat, tool.description
1234            )
1235        } else {
1236            format!(
1237                "- **{}** (provider: {}): {}",
1238                tool.name, provider.name, tool.description
1239            )
1240        };
1241        if !tool.tags.is_empty() {
1242            summary.push_str(&format!("\n  Tags: {}", tool.tags.join(", ")));
1243        }
1244        // CLI tools: show passthrough usage
1245        if provider.is_cli() && tool.input_schema.is_none() {
1246            let cmd = provider.cli_command.as_deref().unwrap_or("?");
1247            summary.push_str(&format!(
1248                "\n  Usage: `ati run {} -- <args>`  (passthrough to `{}`)",
1249                tool.name, cmd
1250            ));
1251        } else if let Some(schema) = &tool.input_schema {
1252            if let Some(props) = schema.get("properties") {
1253                if let Some(obj) = props.as_object() {
1254                    let params: Vec<String> = obj
1255                        .iter()
1256                        .filter(|(_, v)| {
1257                            v.get("x-ati-param-location").is_none()
1258                                || v.get("description").is_some()
1259                        })
1260                        .map(|(k, v)| {
1261                            let type_str =
1262                                v.get("type").and_then(|t| t.as_str()).unwrap_or("string");
1263                            let desc = v.get("description").and_then(|d| d.as_str()).unwrap_or("");
1264                            format!("    --{k} ({type_str}): {desc}")
1265                        })
1266                        .collect();
1267                    if !params.is_empty() {
1268                        summary.push_str("\n  Parameters:\n");
1269                        summary.push_str(&params.join("\n"));
1270                    }
1271                }
1272            }
1273        }
1274        summaries.push(summary);
1275    }
1276    summaries.join("\n\n")
1277}
1278
1279/// Build a scoped system prompt for a specific tool or provider.
1280///
1281/// Returns None if the scope_name doesn't match any tool or provider.
1282fn build_scoped_prompt(
1283    scope_name: &str,
1284    registry: &ManifestRegistry,
1285    skills_section: &str,
1286) -> Option<String> {
1287    // Check if scope_name is a tool
1288    if let Some((provider, tool)) = registry.get_tool(scope_name) {
1289        let mut details = format!(
1290            "**Name**: `{}`\n**Provider**: {} (handler: {})\n**Description**: {}\n",
1291            tool.name, provider.name, provider.handler, tool.description
1292        );
1293        if let Some(cat) = &provider.category {
1294            details.push_str(&format!("**Category**: {}\n", cat));
1295        }
1296        if provider.is_cli() {
1297            let cmd = provider.cli_command.as_deref().unwrap_or("?");
1298            details.push_str(&format!(
1299                "\n**Usage**: `ati run {} -- <args>`  (passthrough to `{}`)\n",
1300                tool.name, cmd
1301            ));
1302        } else if let Some(schema) = &tool.input_schema {
1303            if let Some(props) = schema.get("properties").and_then(|p| p.as_object()) {
1304                let required: Vec<String> = schema
1305                    .get("required")
1306                    .and_then(|r| r.as_array())
1307                    .map(|arr| {
1308                        arr.iter()
1309                            .filter_map(|v| v.as_str().map(|s| s.to_string()))
1310                            .collect()
1311                    })
1312                    .unwrap_or_default();
1313                details.push_str("\n**Parameters**:\n");
1314                for (key, val) in props {
1315                    let type_str = val.get("type").and_then(|t| t.as_str()).unwrap_or("string");
1316                    let desc = val
1317                        .get("description")
1318                        .and_then(|d| d.as_str())
1319                        .unwrap_or("");
1320                    let req = if required.contains(key) {
1321                        " **(required)**"
1322                    } else {
1323                        ""
1324                    };
1325                    details.push_str(&format!("- `--{key}` ({type_str}{req}): {desc}\n"));
1326                }
1327            }
1328        }
1329
1330        let prompt = format!(
1331            "You are an expert assistant for the `{}` tool, accessed via the `ati` CLI.\n\n\
1332            ## Tool Details\n{}\n\n{}\n\n\
1333            Answer the agent's question about this specific tool. Provide exact commands, explain flags and options, and give practical examples. Be concise and actionable.",
1334            tool.name, details, skills_section
1335        );
1336        return Some(prompt);
1337    }
1338
1339    // Check if scope_name is a provider
1340    if registry.has_provider(scope_name) {
1341        let tools = registry.tools_by_provider(scope_name);
1342        if tools.is_empty() {
1343            return None;
1344        }
1345        let tools_context = build_tool_context(&tools);
1346        let prompt = format!(
1347            "You are an expert assistant for the `{}` provider's tools, accessed via the `ati` CLI.\n\n\
1348            ## Tools in provider `{}`\n{}\n\n{}\n\n\
1349            Answer the agent's question about these tools. Provide exact `ati run` commands, explain parameters, and give practical examples. Be concise and actionable.",
1350            scope_name, scope_name, tools_context, skills_section
1351        );
1352        return Some(prompt);
1353    }
1354
1355    None
1356}