Skip to main content

ati/proxy/
server.rs

//! ATI proxy server — holds API keys and executes tool calls on behalf of sandbox agents.
//!
//! Authentication: ES256-signed JWT (HS256 is accepted as a fallback). The JWT carries the
//! caller's identity, scopes and expiry; static tokens and unsigned scope lists are no longer
//! supported.
//!
//! Usage: `ati proxy --port 8080 [--ati-dir ~/.ati]`
7use axum::{
8    body::Body,
9    extract::State,
10    http::{Request as HttpRequest, StatusCode},
11    middleware::{self, Next},
12    response::{IntoResponse, Response},
13    routing::{get, post},
14    Json, Router,
15};
16use serde::{Deserialize, Serialize};
17use serde_json::Value;
18use std::collections::HashMap;
19use std::net::SocketAddr;
20use std::path::PathBuf;
21use std::sync::Arc;
22
23use crate::core::auth_generator::{AuthCache, GenContext};
24use crate::core::http;
25use crate::core::jwt::{self, JwtConfig, TokenClaims};
26use crate::core::keyring::Keyring;
27use crate::core::manifest::ManifestRegistry;
28use crate::core::mcp_client;
29use crate::core::response;
30use crate::core::scope::ScopeConfig;
31use crate::core::skill::{self, SkillRegistry};
32use crate::core::xai;
33
34/// Shared state for the proxy server.
35pub struct ProxyState {
36    pub registry: ManifestRegistry,
37    pub skill_registry: SkillRegistry,
38    pub keyring: Keyring,
39    /// JWT validation config (None = auth disabled / dev mode).
40    pub jwt_config: Option<JwtConfig>,
41    /// Pre-computed JWKS JSON for the /.well-known/jwks.json endpoint.
42    pub jwks_json: Option<Value>,
43    /// Shared cache for dynamically generated auth credentials.
44    pub auth_cache: AuthCache,
45}
46
47// --- Request/Response types ---
48
49#[derive(Debug, Deserialize)]
50pub struct CallRequest {
51    pub tool_name: String,
52    /// Tool arguments — accepts a JSON object (key-value pairs) for HTTP/MCP/OpenAPI tools,
53    /// or a JSON array of strings / a single string for CLI tools.
54    /// The proxy auto-detects the handler type and routes accordingly.
55    #[serde(default = "default_args")]
56    pub args: Value,
57    /// Deprecated: use `args` with an array value instead.
58    /// Kept for backward compatibility — if present, takes precedence for CLI tools.
59    #[serde(default)]
60    pub raw_args: Option<Vec<String>>,
61}
62
63fn default_args() -> Value {
64    Value::Object(serde_json::Map::new())
65}
66
67impl CallRequest {
68    /// Extract args as a HashMap for HTTP/MCP/OpenAPI tools.
69    /// If `args` is a JSON object, returns its entries.
70    /// If `args` is something else (array, string), returns an empty map.
71    fn args_as_map(&self) -> HashMap<String, Value> {
72        match &self.args {
73            Value::Object(map) => map.iter().map(|(k, v)| (k.clone(), v.clone())).collect(),
74            _ => HashMap::new(),
75        }
76    }
77
78    /// Extract positional args for CLI tools.
79    /// Priority: explicit `raw_args` field > `args` array > `args` string > `args._positional` > empty.
80    fn args_as_positional(&self) -> Vec<String> {
81        // Backward compat: explicit raw_args wins
82        if let Some(ref raw) = self.raw_args {
83            return raw.clone();
84        }
85        match &self.args {
86            // ["pr", "list", "--repo", "X"]
87            Value::Array(arr) => arr
88                .iter()
89                .map(|v| match v {
90                    Value::String(s) => s.clone(),
91                    other => other.to_string(),
92                })
93                .collect(),
94            // "pr list --repo X"
95            Value::String(s) => s.split_whitespace().map(String::from).collect(),
96            // {"_positional": ["pr", "list"]} or {"--key": "value"} converted to CLI flags
97            Value::Object(map) => {
98                if let Some(Value::Array(pos)) = map.get("_positional") {
99                    return pos
100                        .iter()
101                        .map(|v| match v {
102                            Value::String(s) => s.clone(),
103                            other => other.to_string(),
104                        })
105                        .collect();
106                }
107                // Convert map entries to --key value pairs
108                let mut result = Vec::new();
109                for (k, v) in map {
110                    result.push(format!("--{k}"));
111                    match v {
112                        Value::String(s) => result.push(s.clone()),
113                        Value::Bool(true) => {} // flag, no value needed
114                        other => result.push(other.to_string()),
115                    }
116                }
117                result
118            }
119            _ => Vec::new(),
120        }
121    }
122}
123
124#[derive(Debug, Serialize)]
125pub struct CallResponse {
126    pub result: Value,
127    #[serde(skip_serializing_if = "Option::is_none")]
128    pub error: Option<String>,
129}
130
131#[derive(Debug, Deserialize)]
132pub struct HelpRequest {
133    pub query: String,
134    #[serde(default)]
135    pub tool: Option<String>,
136}
137
138#[derive(Debug, Serialize)]
139pub struct HelpResponse {
140    pub content: String,
141    #[serde(skip_serializing_if = "Option::is_none")]
142    pub error: Option<String>,
143}
144
145#[derive(Debug, Serialize)]
146pub struct HealthResponse {
147    pub status: String,
148    pub version: String,
149    pub tools: usize,
150    pub providers: usize,
151    pub skills: usize,
152    pub auth: String,
153}
154
155// --- Skill endpoint types ---
156
157#[derive(Debug, Deserialize)]
158pub struct SkillsQuery {
159    #[serde(default)]
160    pub category: Option<String>,
161    #[serde(default)]
162    pub provider: Option<String>,
163    #[serde(default)]
164    pub tool: Option<String>,
165    #[serde(default)]
166    pub search: Option<String>,
167}
168
169#[derive(Debug, Deserialize)]
170pub struct SkillDetailQuery {
171    #[serde(default)]
172    pub meta: Option<bool>,
173    #[serde(default)]
174    pub refs: Option<bool>,
175}
176
177#[derive(Debug, Deserialize)]
178pub struct SkillResolveRequest {
179    pub scopes: Vec<String>,
180}
181
182// --- Handlers ---
183
184async fn handle_call(
185    State(state): State<Arc<ProxyState>>,
186    req: HttpRequest<Body>,
187) -> impl IntoResponse {
188    // Extract JWT claims from request extensions (set by auth middleware)
189    let claims = req.extensions().get::<TokenClaims>().cloned();
190
191    // Parse request body
192    let body_bytes = match axum::body::to_bytes(req.into_body(), 10 * 1024 * 1024).await {
193        Ok(b) => b,
194        Err(e) => {
195            return (
196                StatusCode::BAD_REQUEST,
197                Json(CallResponse {
198                    result: Value::Null,
199                    error: Some(format!("Failed to read request body: {e}")),
200                }),
201            );
202        }
203    };
204
205    let call_req: CallRequest = match serde_json::from_slice(&body_bytes) {
206        Ok(r) => r,
207        Err(e) => {
208            return (
209                StatusCode::UNPROCESSABLE_ENTITY,
210                Json(CallResponse {
211                    result: Value::Null,
212                    error: Some(format!("Invalid request: {e}")),
213                }),
214            );
215        }
216    };
217
218    tracing::debug!(
219        tool = %call_req.tool_name,
220        args = ?call_req.args,
221        "POST /call"
222    );
223
224    // Look up tool in registry
225    let (provider, tool) = match state.registry.get_tool(&call_req.tool_name) {
226        Some(pt) => pt,
227        None => {
228            return (
229                StatusCode::NOT_FOUND,
230                Json(CallResponse {
231                    result: Value::Null,
232                    error: Some(format!("Unknown tool: '{}'", call_req.tool_name)),
233                }),
234            );
235        }
236    };
237
238    // Scope enforcement from JWT claims
239    if let Some(tool_scope) = &tool.scope {
240        let scopes = match &claims {
241            Some(c) => ScopeConfig::from_jwt(c),
242            None if state.jwt_config.is_none() => ScopeConfig::unrestricted(), // Dev mode
243            None => {
244                return (
245                    StatusCode::FORBIDDEN,
246                    Json(CallResponse {
247                        result: Value::Null,
248                        error: Some("Authentication required — no JWT provided".into()),
249                    }),
250                );
251            }
252        };
253
254        if let Err(e) = scopes.check_access(&call_req.tool_name, tool_scope) {
255            return (
256                StatusCode::FORBIDDEN,
257                Json(CallResponse {
258                    result: Value::Null,
259                    error: Some(format!("Access denied: {e}")),
260                }),
261            );
262        }
263    }
264
265    // Rate limit check
266    {
267        let scopes = match &claims {
268            Some(c) => ScopeConfig::from_jwt(c),
269            None => ScopeConfig::unrestricted(),
270        };
271        if let Some(ref rate_config) = scopes.rate_config {
272            if let Err(e) = crate::core::rate::check_and_record(&call_req.tool_name, rate_config) {
273                return (
274                    StatusCode::TOO_MANY_REQUESTS,
275                    Json(CallResponse {
276                        result: Value::Null,
277                        error: Some(format!("{e}")),
278                    }),
279                );
280            }
281        }
282    }
283
284    // Build auth generator context from JWT claims
285    let gen_ctx = GenContext {
286        jwt_sub: claims
287            .as_ref()
288            .map(|c| c.sub.clone())
289            .unwrap_or_else(|| "dev".into()),
290        jwt_scope: claims
291            .as_ref()
292            .map(|c| c.scope.clone())
293            .unwrap_or_else(|| "*".into()),
294        tool_name: call_req.tool_name.clone(),
295        timestamp: crate::core::jwt::now_secs(),
296    };
297
298    // Execute tool call — dispatch based on handler type, with timing for audit
299    let agent_sub = claims.as_ref().map(|c| c.sub.clone()).unwrap_or_default();
300    let start = std::time::Instant::now();
301
302    let response = match provider.handler.as_str() {
303        "mcp" => {
304            let args_map = call_req.args_as_map();
305            match mcp_client::execute_with_gen(
306                provider,
307                &call_req.tool_name,
308                &args_map,
309                &state.keyring,
310                Some(&gen_ctx),
311                Some(&state.auth_cache),
312            )
313            .await
314            {
315                Ok(result) => (
316                    StatusCode::OK,
317                    Json(CallResponse {
318                        result,
319                        error: None,
320                    }),
321                ),
322                Err(e) => (
323                    StatusCode::BAD_GATEWAY,
324                    Json(CallResponse {
325                        result: Value::Null,
326                        error: Some(format!("MCP error: {e}")),
327                    }),
328                ),
329            }
330        }
331        "cli" => {
332            let positional = call_req.args_as_positional();
333            match crate::core::cli_executor::execute_with_gen(
334                provider,
335                &positional,
336                &state.keyring,
337                Some(&gen_ctx),
338                Some(&state.auth_cache),
339            )
340            .await
341            {
342                Ok(result) => (
343                    StatusCode::OK,
344                    Json(CallResponse {
345                        result,
346                        error: None,
347                    }),
348                ),
349                Err(e) => (
350                    StatusCode::BAD_GATEWAY,
351                    Json(CallResponse {
352                        result: Value::Null,
353                        error: Some(format!("CLI error: {e}")),
354                    }),
355                ),
356            }
357        }
358        _ => {
359            let args_map = call_req.args_as_map();
360            let raw_response = match match provider.handler.as_str() {
361                "xai" => xai::execute_xai_tool(provider, tool, &args_map, &state.keyring).await,
362                _ => {
363                    http::execute_tool_with_gen(
364                        provider,
365                        tool,
366                        &args_map,
367                        &state.keyring,
368                        Some(&gen_ctx),
369                        Some(&state.auth_cache),
370                    )
371                    .await
372                }
373            } {
374                Ok(resp) => resp,
375                Err(e) => {
376                    let duration = start.elapsed();
377                    write_proxy_audit(&call_req, &agent_sub, duration, Some(&e.to_string()));
378                    return (
379                        StatusCode::BAD_GATEWAY,
380                        Json(CallResponse {
381                            result: Value::Null,
382                            error: Some(format!("Upstream API error: {e}")),
383                        }),
384                    );
385                }
386            };
387
388            let processed = match response::process_response(&raw_response, tool.response.as_ref())
389            {
390                Ok(p) => p,
391                Err(e) => {
392                    let duration = start.elapsed();
393                    write_proxy_audit(&call_req, &agent_sub, duration, Some(&e.to_string()));
394                    return (
395                        StatusCode::INTERNAL_SERVER_ERROR,
396                        Json(CallResponse {
397                            result: raw_response,
398                            error: Some(format!("Response processing error: {e}")),
399                        }),
400                    );
401                }
402            };
403
404            (
405                StatusCode::OK,
406                Json(CallResponse {
407                    result: processed,
408                    error: None,
409                }),
410            )
411        }
412    };
413
414    let duration = start.elapsed();
415    let error_msg = response.1.error.as_deref();
416    write_proxy_audit(&call_req, &agent_sub, duration, error_msg);
417
418    response
419}
420
421async fn handle_help(
422    State(state): State<Arc<ProxyState>>,
423    Json(req): Json<HelpRequest>,
424) -> impl IntoResponse {
425    tracing::debug!(query = %req.query, tool = ?req.tool, "POST /help");
426
427    let (llm_provider, llm_tool) = match state.registry.get_tool("_chat_completion") {
428        Some(pt) => pt,
429        None => {
430            return (
431                StatusCode::SERVICE_UNAVAILABLE,
432                Json(HelpResponse {
433                    content: String::new(),
434                    error: Some("No _llm.toml manifest found. Proxy help requires a configured LLM provider.".into()),
435                }),
436            );
437        }
438    };
439
440    let api_key = match llm_provider
441        .auth_key_name
442        .as_deref()
443        .and_then(|k| state.keyring.get(k))
444    {
445        Some(key) => key.to_string(),
446        None => {
447            return (
448                StatusCode::SERVICE_UNAVAILABLE,
449                Json(HelpResponse {
450                    content: String::new(),
451                    error: Some("LLM API key not found in keyring".into()),
452                }),
453            );
454        }
455    };
456
457    let scopes = ScopeConfig::unrestricted();
458    let resolved_skills = skill::resolve_skills(&state.skill_registry, &state.registry, &scopes);
459    let skills_section = if resolved_skills.is_empty() {
460        String::new()
461    } else {
462        format!(
463            "## Available Skills (methodology guides)\n{}",
464            skill::build_skill_context(&resolved_skills)
465        )
466    };
467
468    // Build system prompt — scoped or unscoped
469    let system_prompt = if let Some(ref tool_name) = req.tool {
470        // Scoped mode: narrow tools to the specified tool or provider
471        match build_scoped_prompt(tool_name, &state.registry, &skills_section) {
472            Some(prompt) => prompt,
473            None => {
474                // Fall back to unscoped if tool/provider not found
475                tracing::debug!(scope = %tool_name, "scope not found, falling back to unscoped");
476                let all_tools = state.registry.list_public_tools();
477                let tools_context = build_tool_context(&all_tools);
478                HELP_SYSTEM_PROMPT
479                    .replace("{tools}", &tools_context)
480                    .replace("{skills_section}", &skills_section)
481            }
482        }
483    } else {
484        let all_tools = state.registry.list_public_tools();
485        let tools_context = build_tool_context(&all_tools);
486        HELP_SYSTEM_PROMPT
487            .replace("{tools}", &tools_context)
488            .replace("{skills_section}", &skills_section)
489    };
490
491    let request_body = serde_json::json!({
492        "model": "zai-glm-4.7",
493        "messages": [
494            {"role": "system", "content": system_prompt},
495            {"role": "user", "content": req.query}
496        ],
497        "max_completion_tokens": 1536,
498        "temperature": 0.3
499    });
500
501    let client = reqwest::Client::new();
502    let url = format!(
503        "{}{}",
504        llm_provider.base_url.trim_end_matches('/'),
505        llm_tool.endpoint
506    );
507
508    let response = match client
509        .post(&url)
510        .bearer_auth(&api_key)
511        .json(&request_body)
512        .send()
513        .await
514    {
515        Ok(r) => r,
516        Err(e) => {
517            return (
518                StatusCode::BAD_GATEWAY,
519                Json(HelpResponse {
520                    content: String::new(),
521                    error: Some(format!("LLM request failed: {e}")),
522                }),
523            );
524        }
525    };
526
527    if !response.status().is_success() {
528        let status = response.status();
529        let body = response.text().await.unwrap_or_default();
530        return (
531            StatusCode::BAD_GATEWAY,
532            Json(HelpResponse {
533                content: String::new(),
534                error: Some(format!("LLM API error ({status}): {body}")),
535            }),
536        );
537    }
538
539    let body: Value = match response.json().await {
540        Ok(b) => b,
541        Err(e) => {
542            return (
543                StatusCode::INTERNAL_SERVER_ERROR,
544                Json(HelpResponse {
545                    content: String::new(),
546                    error: Some(format!("Failed to parse LLM response: {e}")),
547                }),
548            );
549        }
550    };
551
552    let content = body
553        .pointer("/choices/0/message/content")
554        .and_then(|c| c.as_str())
555        .unwrap_or("No response from LLM")
556        .to_string();
557
558    (
559        StatusCode::OK,
560        Json(HelpResponse {
561            content,
562            error: None,
563        }),
564    )
565}
566
567async fn handle_health(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
568    let auth = if state.jwt_config.is_some() {
569        "jwt"
570    } else {
571        "disabled"
572    };
573
574    Json(HealthResponse {
575        status: "ok".into(),
576        version: env!("CARGO_PKG_VERSION").into(),
577        tools: state.registry.list_public_tools().len(),
578        providers: state.registry.list_providers().len(),
579        skills: state.skill_registry.skill_count(),
580        auth: auth.into(),
581    })
582}
583
584/// GET /.well-known/jwks.json — serves the public key for JWT validation.
585async fn handle_jwks(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
586    match &state.jwks_json {
587        Some(jwks) => (StatusCode::OK, Json(jwks.clone())),
588        None => (
589            StatusCode::NOT_FOUND,
590            Json(serde_json::json!({"error": "JWKS not configured"})),
591        ),
592    }
593}
594
595// ---------------------------------------------------------------------------
596// POST /mcp — MCP JSON-RPC proxy endpoint
597// ---------------------------------------------------------------------------
598
599async fn handle_mcp(
600    State(state): State<Arc<ProxyState>>,
601    Json(msg): Json<Value>,
602) -> impl IntoResponse {
603    let method = msg.get("method").and_then(|m| m.as_str()).unwrap_or("");
604    let id = msg.get("id").cloned();
605
606    tracing::debug!(%method, "POST /mcp");
607
608    match method {
609        "initialize" => {
610            let result = serde_json::json!({
611                "protocolVersion": "2025-03-26",
612                "capabilities": {
613                    "tools": { "listChanged": false }
614                },
615                "serverInfo": {
616                    "name": "ati-proxy",
617                    "version": env!("CARGO_PKG_VERSION")
618                }
619            });
620            jsonrpc_success(id, result)
621        }
622
623        "notifications/initialized" => (StatusCode::ACCEPTED, Json(Value::Null)),
624
625        "tools/list" => {
626            let all_tools = state.registry.list_public_tools();
627            let mcp_tools: Vec<Value> = all_tools
628                .iter()
629                .map(|(_provider, tool)| {
630                    serde_json::json!({
631                        "name": tool.name,
632                        "description": tool.description,
633                        "inputSchema": tool.input_schema.clone().unwrap_or(serde_json::json!({
634                            "type": "object",
635                            "properties": {}
636                        }))
637                    })
638                })
639                .collect();
640
641            let result = serde_json::json!({
642                "tools": mcp_tools,
643            });
644            jsonrpc_success(id, result)
645        }
646
647        "tools/call" => {
648            let params = msg.get("params").cloned().unwrap_or(Value::Null);
649            let tool_name = params.get("name").and_then(|n| n.as_str()).unwrap_or("");
650            let arguments: HashMap<String, Value> = params
651                .get("arguments")
652                .and_then(|a| serde_json::from_value(a.clone()).ok())
653                .unwrap_or_default();
654
655            if tool_name.is_empty() {
656                return jsonrpc_error(id, -32602, "Missing tool name in params.name");
657            }
658
659            let (provider, _tool) = match state.registry.get_tool(tool_name) {
660                Some(pt) => pt,
661                None => {
662                    return jsonrpc_error(id, -32602, &format!("Unknown tool: '{tool_name}'"));
663                }
664            };
665
666            tracing::debug!(%tool_name, provider = %provider.name, "MCP tools/call");
667
668            let mcp_gen_ctx = GenContext {
669                jwt_sub: "dev".into(),
670                jwt_scope: "*".into(),
671                tool_name: tool_name.to_string(),
672                timestamp: crate::core::jwt::now_secs(),
673            };
674
675            let result = if provider.is_mcp() {
676                mcp_client::execute_with_gen(
677                    provider,
678                    tool_name,
679                    &arguments,
680                    &state.keyring,
681                    Some(&mcp_gen_ctx),
682                    Some(&state.auth_cache),
683                )
684                .await
685            } else if provider.is_cli() {
686                // Convert arguments map to CLI-style args for MCP passthrough
687                let raw: Vec<String> = arguments
688                    .iter()
689                    .flat_map(|(k, v)| {
690                        let val = match v {
691                            Value::String(s) => s.clone(),
692                            other => other.to_string(),
693                        };
694                        vec![format!("--{k}"), val]
695                    })
696                    .collect();
697                crate::core::cli_executor::execute_with_gen(
698                    provider,
699                    &raw,
700                    &state.keyring,
701                    Some(&mcp_gen_ctx),
702                    Some(&state.auth_cache),
703                )
704                .await
705                .map_err(|e| mcp_client::McpError::Transport(e.to_string()))
706            } else {
707                match match provider.handler.as_str() {
708                    "xai" => {
709                        xai::execute_xai_tool(provider, _tool, &arguments, &state.keyring).await
710                    }
711                    _ => {
712                        http::execute_tool_with_gen(
713                            provider,
714                            _tool,
715                            &arguments,
716                            &state.keyring,
717                            Some(&mcp_gen_ctx),
718                            Some(&state.auth_cache),
719                        )
720                        .await
721                    }
722                } {
723                    Ok(val) => Ok(val),
724                    Err(e) => Err(mcp_client::McpError::Transport(e.to_string())),
725                }
726            };
727
728            match result {
729                Ok(value) => {
730                    let text = match &value {
731                        Value::String(s) => s.clone(),
732                        other => serde_json::to_string_pretty(other).unwrap_or_default(),
733                    };
734                    let mcp_result = serde_json::json!({
735                        "content": [{"type": "text", "text": text}],
736                        "isError": false,
737                    });
738                    jsonrpc_success(id, mcp_result)
739                }
740                Err(e) => {
741                    let mcp_result = serde_json::json!({
742                        "content": [{"type": "text", "text": format!("Error: {e}")}],
743                        "isError": true,
744                    });
745                    jsonrpc_success(id, mcp_result)
746                }
747            }
748        }
749
750        _ => jsonrpc_error(id, -32601, &format!("Method not found: '{method}'")),
751    }
752}
753
754fn jsonrpc_success(id: Option<Value>, result: Value) -> (StatusCode, Json<Value>) {
755    (
756        StatusCode::OK,
757        Json(serde_json::json!({
758            "jsonrpc": "2.0",
759            "id": id,
760            "result": result,
761        })),
762    )
763}
764
765fn jsonrpc_error(id: Option<Value>, code: i64, message: &str) -> (StatusCode, Json<Value>) {
766    (
767        StatusCode::OK,
768        Json(serde_json::json!({
769            "jsonrpc": "2.0",
770            "id": id,
771            "error": {
772                "code": code,
773                "message": message,
774            }
775        })),
776    )
777}
778
779// ---------------------------------------------------------------------------
780// Skill endpoints
781// ---------------------------------------------------------------------------
782
783async fn handle_skills_list(
784    State(state): State<Arc<ProxyState>>,
785    axum::extract::Query(query): axum::extract::Query<SkillsQuery>,
786) -> impl IntoResponse {
787    tracing::debug!(
788        category = ?query.category,
789        provider = ?query.provider,
790        tool = ?query.tool,
791        search = ?query.search,
792        "GET /skills"
793    );
794
795    let skills: Vec<&skill::SkillMeta> = if let Some(search_query) = &query.search {
796        state.skill_registry.search(search_query)
797    } else if let Some(cat) = &query.category {
798        state.skill_registry.skills_for_category(cat)
799    } else if let Some(prov) = &query.provider {
800        state.skill_registry.skills_for_provider(prov)
801    } else if let Some(t) = &query.tool {
802        state.skill_registry.skills_for_tool(t)
803    } else {
804        state.skill_registry.list_skills().iter().collect()
805    };
806
807    let json: Vec<Value> = skills
808        .iter()
809        .map(|s| {
810            serde_json::json!({
811                "name": s.name,
812                "version": s.version,
813                "description": s.description,
814                "tools": s.tools,
815                "providers": s.providers,
816                "categories": s.categories,
817                "hint": s.hint,
818            })
819        })
820        .collect();
821
822    (StatusCode::OK, Json(Value::Array(json)))
823}
824
825async fn handle_skill_detail(
826    State(state): State<Arc<ProxyState>>,
827    axum::extract::Path(name): axum::extract::Path<String>,
828    axum::extract::Query(query): axum::extract::Query<SkillDetailQuery>,
829) -> impl IntoResponse {
830    tracing::debug!(%name, meta = ?query.meta, refs = ?query.refs, "GET /skills/:name");
831
832    let skill_meta = match state.skill_registry.get_skill(&name) {
833        Some(s) => s,
834        None => {
835            return (
836                StatusCode::NOT_FOUND,
837                Json(serde_json::json!({"error": format!("Skill '{name}' not found")})),
838            );
839        }
840    };
841
842    if query.meta.unwrap_or(false) {
843        return (
844            StatusCode::OK,
845            Json(serde_json::json!({
846                "name": skill_meta.name,
847                "version": skill_meta.version,
848                "description": skill_meta.description,
849                "author": skill_meta.author,
850                "tools": skill_meta.tools,
851                "providers": skill_meta.providers,
852                "categories": skill_meta.categories,
853                "keywords": skill_meta.keywords,
854                "hint": skill_meta.hint,
855                "depends_on": skill_meta.depends_on,
856                "suggests": skill_meta.suggests,
857                "license": skill_meta.license,
858                "compatibility": skill_meta.compatibility,
859                "allowed_tools": skill_meta.allowed_tools,
860                "format": skill_meta.format,
861            })),
862        );
863    }
864
865    let content = match state.skill_registry.read_content(&name) {
866        Ok(c) => c,
867        Err(e) => {
868            return (
869                StatusCode::INTERNAL_SERVER_ERROR,
870                Json(serde_json::json!({"error": format!("Failed to read skill: {e}")})),
871            );
872        }
873    };
874
875    let mut response = serde_json::json!({
876        "name": skill_meta.name,
877        "version": skill_meta.version,
878        "description": skill_meta.description,
879        "content": content,
880    });
881
882    if query.refs.unwrap_or(false) {
883        if let Ok(refs) = state.skill_registry.list_references(&name) {
884            response["references"] = serde_json::json!(refs);
885        }
886    }
887
888    (StatusCode::OK, Json(response))
889}
890
891async fn handle_skills_resolve(
892    State(state): State<Arc<ProxyState>>,
893    Json(req): Json<SkillResolveRequest>,
894) -> impl IntoResponse {
895    tracing::debug!(scopes = ?req.scopes, "POST /skills/resolve");
896
897    let scopes = ScopeConfig {
898        scopes: req.scopes,
899        sub: String::new(),
900        expires_at: 0,
901        rate_config: None,
902    };
903
904    let resolved = skill::resolve_skills(&state.skill_registry, &state.registry, &scopes);
905
906    let json: Vec<Value> = resolved
907        .iter()
908        .map(|s| {
909            serde_json::json!({
910                "name": s.name,
911                "version": s.version,
912                "description": s.description,
913                "tools": s.tools,
914                "providers": s.providers,
915                "categories": s.categories,
916            })
917        })
918        .collect();
919
920    (StatusCode::OK, Json(Value::Array(json)))
921}
922
923// --- Auth middleware ---
924
925/// JWT authentication middleware.
926///
927/// - /health and /.well-known/jwks.json → skip auth
928/// - JWT configured → validate Bearer token, attach claims to request extensions
929/// - No JWT configured → allow all (dev mode)
930async fn auth_middleware(
931    State(state): State<Arc<ProxyState>>,
932    mut req: HttpRequest<Body>,
933    next: Next,
934) -> Result<Response, StatusCode> {
935    let path = req.uri().path();
936
937    // Skip auth for public endpoints
938    if path == "/health" || path == "/.well-known/jwks.json" {
939        return Ok(next.run(req).await);
940    }
941
942    // If no JWT configured, allow all (dev mode)
943    let jwt_config = match &state.jwt_config {
944        Some(c) => c,
945        None => return Ok(next.run(req).await),
946    };
947
948    // Extract Authorization: Bearer <token>
949    let auth_header = req
950        .headers()
951        .get("authorization")
952        .and_then(|v| v.to_str().ok());
953
954    let token = match auth_header {
955        Some(header) if header.starts_with("Bearer ") => &header[7..],
956        _ => return Err(StatusCode::UNAUTHORIZED),
957    };
958
959    // Validate JWT
960    match jwt::validate(token, jwt_config) {
961        Ok(claims) => {
962            tracing::debug!(sub = %claims.sub, scopes = %claims.scope, "JWT validated");
963            req.extensions_mut().insert(claims);
964            Ok(next.run(req).await)
965        }
966        Err(e) => {
967            tracing::debug!(error = %e, "JWT validation failed");
968            Err(StatusCode::UNAUTHORIZED)
969        }
970    }
971}
972
973// --- Router builder ---
974
975/// Build the axum Router from a pre-constructed ProxyState.
976pub fn build_router(state: Arc<ProxyState>) -> Router {
977    Router::new()
978        .route("/call", post(handle_call))
979        .route("/help", post(handle_help))
980        .route("/mcp", post(handle_mcp))
981        .route("/skills", get(handle_skills_list))
982        .route("/skills/resolve", post(handle_skills_resolve))
983        .route("/skills/{name}", get(handle_skill_detail))
984        .route("/health", get(handle_health))
985        .route("/.well-known/jwks.json", get(handle_jwks))
986        .layer(middleware::from_fn_with_state(
987            state.clone(),
988            auth_middleware,
989        ))
990        .with_state(state)
991}
992
993// --- Server startup ---
994
995/// Start the proxy server.
996pub async fn run(
997    port: u16,
998    bind_addr: Option<String>,
999    ati_dir: PathBuf,
1000    _verbose: bool,
1001    env_keys: bool,
1002) -> Result<(), Box<dyn std::error::Error>> {
1003    // Load manifests
1004    let manifests_dir = ati_dir.join("manifests");
1005    let registry = ManifestRegistry::load(&manifests_dir)?;
1006
1007    let tool_count = registry.list_public_tools().len();
1008    let provider_count = registry.list_providers().len();
1009
1010    // Load keyring
1011    let keyring_source;
1012    let keyring = if env_keys {
1013        // --env-keys: scan ATI_KEY_* environment variables
1014        let kr = Keyring::from_env();
1015        let key_names = kr.key_names();
1016        tracing::info!(
1017            count = key_names.len(),
1018            "loaded API keys from ATI_KEY_* env vars"
1019        );
1020        for name in &key_names {
1021            tracing::debug!(key = %name, "env key loaded");
1022        }
1023        keyring_source = "env-vars (ATI_KEY_*)";
1024        kr
1025    } else {
1026        // Cascade: keyring.enc (sealed) → keyring.enc (persistent) → credentials → empty
1027        let keyring_path = ati_dir.join("keyring.enc");
1028        if keyring_path.exists() {
1029            if let Ok(kr) = Keyring::load(&keyring_path) {
1030                keyring_source = "keyring.enc (sealed key)";
1031                kr
1032            } else if let Ok(kr) = Keyring::load_local(&keyring_path, &ati_dir) {
1033                keyring_source = "keyring.enc (persistent key)";
1034                kr
1035            } else {
1036                tracing::warn!("keyring.enc exists but could not be decrypted");
1037                keyring_source = "empty (decryption failed)";
1038                Keyring::empty()
1039            }
1040        } else {
1041            let creds_path = ati_dir.join("credentials");
1042            if creds_path.exists() {
1043                match Keyring::load_credentials(&creds_path) {
1044                    Ok(kr) => {
1045                        keyring_source = "credentials (plaintext)";
1046                        kr
1047                    }
1048                    Err(e) => {
1049                        tracing::warn!(error = %e, "failed to load credentials");
1050                        keyring_source = "empty (credentials error)";
1051                        Keyring::empty()
1052                    }
1053                }
1054            } else {
1055                tracing::warn!("no keyring.enc or credentials found — running without API keys");
1056                tracing::warn!("tools requiring authentication will fail");
1057                keyring_source = "empty (no auth)";
1058                Keyring::empty()
1059            }
1060        }
1061    };
1062
1063    // Log MCP and OpenAPI providers
1064    let mcp_providers: Vec<(String, String)> = registry
1065        .list_mcp_providers()
1066        .iter()
1067        .map(|p| (p.name.clone(), p.mcp_transport_type().to_string()))
1068        .collect();
1069    let mcp_count = mcp_providers.len();
1070    let openapi_providers: Vec<String> = registry
1071        .list_openapi_providers()
1072        .iter()
1073        .map(|p| p.name.clone())
1074        .collect();
1075    let openapi_count = openapi_providers.len();
1076
1077    // Load skill registry
1078    let skills_dir = ati_dir.join("skills");
1079    let skill_registry = SkillRegistry::load(&skills_dir).unwrap_or_else(|e| {
1080        tracing::warn!(error = %e, "failed to load skills");
1081        SkillRegistry::load(std::path::Path::new("/nonexistent-fallback")).unwrap()
1082    });
1083    let skill_count = skill_registry.skill_count();
1084
1085    // Load JWT config from environment
1086    let jwt_config = match jwt::config_from_env() {
1087        Ok(config) => config,
1088        Err(e) => {
1089            tracing::warn!(error = %e, "JWT config error");
1090            None
1091        }
1092    };
1093
1094    let auth_status = if jwt_config.is_some() {
1095        "JWT enabled"
1096    } else {
1097        "DISABLED (no JWT keys configured)"
1098    };
1099
1100    // Build JWKS for the endpoint
1101    let jwks_json = jwt_config.as_ref().and_then(|config| {
1102        config
1103            .public_key_pem
1104            .as_ref()
1105            .and_then(|pem| jwt::public_key_to_jwks(pem, config.algorithm, "ati-proxy-1").ok())
1106    });
1107
1108    let state = Arc::new(ProxyState {
1109        registry,
1110        skill_registry,
1111        keyring,
1112        jwt_config,
1113        jwks_json,
1114        auth_cache: AuthCache::new(),
1115    });
1116
1117    let app = build_router(state);
1118
1119    let addr: SocketAddr = if let Some(ref bind) = bind_addr {
1120        format!("{bind}:{port}").parse()?
1121    } else {
1122        SocketAddr::from(([127, 0, 0, 1], port))
1123    };
1124
1125    tracing::info!(
1126        version = env!("CARGO_PKG_VERSION"),
1127        %addr,
1128        auth = auth_status,
1129        ati_dir = %ati_dir.display(),
1130        tools = tool_count,
1131        providers = provider_count,
1132        mcp = mcp_count,
1133        openapi = openapi_count,
1134        skills = skill_count,
1135        keyring = keyring_source,
1136        "ATI proxy server starting"
1137    );
1138    for (name, transport) in &mcp_providers {
1139        tracing::info!(provider = %name, transport = %transport, "MCP provider");
1140    }
1141    for name in &openapi_providers {
1142        tracing::info!(provider = %name, "OpenAPI provider");
1143    }
1144
1145    let listener = tokio::net::TcpListener::bind(addr).await?;
1146    axum::serve(listener, app).await?;
1147
1148    Ok(())
1149}
1150
1151/// Write an audit entry from the proxy server. Failures are silently ignored.
1152fn write_proxy_audit(
1153    call_req: &CallRequest,
1154    agent_sub: &str,
1155    duration: std::time::Duration,
1156    error: Option<&str>,
1157) {
1158    let entry = crate::core::audit::AuditEntry {
1159        ts: chrono::Utc::now().to_rfc3339(),
1160        tool: call_req.tool_name.clone(),
1161        args: crate::core::audit::sanitize_args(&call_req.args),
1162        status: if error.is_some() {
1163            crate::core::audit::AuditStatus::Error
1164        } else {
1165            crate::core::audit::AuditStatus::Ok
1166        },
1167        duration_ms: duration.as_millis() as u64,
1168        agent_sub: agent_sub.to_string(),
1169        error: error.map(|s| s.to_string()),
1170        exit_code: None,
1171    };
1172    let _ = crate::core::audit::append(&entry);
1173}
1174
1175// --- Helpers ---
1176
/// General-purpose system prompt template for the help assistant.
///
/// This is the unscoped counterpart to the per-tool/per-provider prompts built by
/// `build_scoped_prompt`.
///
/// The text contains two literal placeholders, `{tools}` and `{skills_section}`. This is a raw
/// string constant, not a `format!` template, so those placeholders are plain text. Presumably
/// the help handler substitutes them, e.g. via `str::replace`. That code is not visible here;
/// confirm before changing the placeholder names.
const HELP_SYSTEM_PROMPT: &str = r#"You are a helpful assistant for an AI agent that uses external tools via the `ati` CLI.

## Available Tools
{tools}

{skills_section}

Answer the agent's question naturally, like a knowledgeable colleague would. Keep it short but useful:

- Explain which tools to use and why, with `ati run` commands showing realistic parameter values
- If multiple steps are needed, walk through them briefly in order
- Mention important gotchas or parameter choices that matter
- If skills are relevant, suggest `ati skill show <name>` for the full methodology

Keep your answer concise — a few short paragraphs with embedded code blocks. Only recommend tools from the list above."#;
1192
1193fn build_tool_context(
1194    tools: &[(
1195        &crate::core::manifest::Provider,
1196        &crate::core::manifest::Tool,
1197    )],
1198) -> String {
1199    let mut summaries = Vec::new();
1200    for (provider, tool) in tools {
1201        let mut summary = if let Some(cat) = &provider.category {
1202            format!(
1203                "- **{}** (provider: {}, category: {}): {}",
1204                tool.name, provider.name, cat, tool.description
1205            )
1206        } else {
1207            format!(
1208                "- **{}** (provider: {}): {}",
1209                tool.name, provider.name, tool.description
1210            )
1211        };
1212        if !tool.tags.is_empty() {
1213            summary.push_str(&format!("\n  Tags: {}", tool.tags.join(", ")));
1214        }
1215        // CLI tools: show passthrough usage
1216        if provider.is_cli() && tool.input_schema.is_none() {
1217            let cmd = provider.cli_command.as_deref().unwrap_or("?");
1218            summary.push_str(&format!(
1219                "\n  Usage: `ati run {} -- <args>`  (passthrough to `{}`)",
1220                tool.name, cmd
1221            ));
1222        } else if let Some(schema) = &tool.input_schema {
1223            if let Some(props) = schema.get("properties") {
1224                if let Some(obj) = props.as_object() {
1225                    let params: Vec<String> = obj
1226                        .iter()
1227                        .filter(|(_, v)| {
1228                            v.get("x-ati-param-location").is_none()
1229                                || v.get("description").is_some()
1230                        })
1231                        .map(|(k, v)| {
1232                            let type_str =
1233                                v.get("type").and_then(|t| t.as_str()).unwrap_or("string");
1234                            let desc = v.get("description").and_then(|d| d.as_str()).unwrap_or("");
1235                            format!("    --{k} ({type_str}): {desc}")
1236                        })
1237                        .collect();
1238                    if !params.is_empty() {
1239                        summary.push_str("\n  Parameters:\n");
1240                        summary.push_str(&params.join("\n"));
1241                    }
1242                }
1243            }
1244        }
1245        summaries.push(summary);
1246    }
1247    summaries.join("\n\n")
1248}
1249
1250/// Build a scoped system prompt for a specific tool or provider.
1251///
1252/// Returns None if the scope_name doesn't match any tool or provider.
1253fn build_scoped_prompt(
1254    scope_name: &str,
1255    registry: &ManifestRegistry,
1256    skills_section: &str,
1257) -> Option<String> {
1258    // Check if scope_name is a tool
1259    if let Some((provider, tool)) = registry.get_tool(scope_name) {
1260        let mut details = format!(
1261            "**Name**: `{}`\n**Provider**: {} (handler: {})\n**Description**: {}\n",
1262            tool.name, provider.name, provider.handler, tool.description
1263        );
1264        if let Some(cat) = &provider.category {
1265            details.push_str(&format!("**Category**: {}\n", cat));
1266        }
1267        if provider.is_cli() {
1268            let cmd = provider.cli_command.as_deref().unwrap_or("?");
1269            details.push_str(&format!(
1270                "\n**Usage**: `ati run {} -- <args>`  (passthrough to `{}`)\n",
1271                tool.name, cmd
1272            ));
1273        } else if let Some(schema) = &tool.input_schema {
1274            if let Some(props) = schema.get("properties").and_then(|p| p.as_object()) {
1275                let required: Vec<String> = schema
1276                    .get("required")
1277                    .and_then(|r| r.as_array())
1278                    .map(|arr| {
1279                        arr.iter()
1280                            .filter_map(|v| v.as_str().map(|s| s.to_string()))
1281                            .collect()
1282                    })
1283                    .unwrap_or_default();
1284                details.push_str("\n**Parameters**:\n");
1285                for (key, val) in props {
1286                    let type_str = val.get("type").and_then(|t| t.as_str()).unwrap_or("string");
1287                    let desc = val
1288                        .get("description")
1289                        .and_then(|d| d.as_str())
1290                        .unwrap_or("");
1291                    let req = if required.contains(key) {
1292                        " **(required)**"
1293                    } else {
1294                        ""
1295                    };
1296                    details.push_str(&format!("- `--{key}` ({type_str}{req}): {desc}\n"));
1297                }
1298            }
1299        }
1300
1301        let prompt = format!(
1302            "You are an expert assistant for the `{}` tool, accessed via the `ati` CLI.\n\n\
1303            ## Tool Details\n{}\n\n{}\n\n\
1304            Answer the agent's question about this specific tool. Provide exact commands, explain flags and options, and give practical examples. Be concise and actionable.",
1305            tool.name, details, skills_section
1306        );
1307        return Some(prompt);
1308    }
1309
1310    // Check if scope_name is a provider
1311    if registry.has_provider(scope_name) {
1312        let tools = registry.tools_by_provider(scope_name);
1313        if tools.is_empty() {
1314            return None;
1315        }
1316        let tools_context = build_tool_context(&tools);
1317        let prompt = format!(
1318            "You are an expert assistant for the `{}` provider's tools, accessed via the `ati` CLI.\n\n\
1319            ## Tools in provider `{}`\n{}\n\n{}\n\n\
1320            Answer the agent's question about these tools. Provide exact `ati run` commands, explain parameters, and give practical examples. Be concise and actionable.",
1321            scope_name, scope_name, tools_context, skills_section
1322        );
1323        return Some(prompt);
1324    }
1325
1326    None
1327}