// Source: agent_tools_interface/proxy/server.rs

//! ATI proxy server — holds API keys and executes tool calls on behalf of sandbox agents.
//!
//! Authentication: ES256-signed JWT (HS256 is supported as a fallback). The JWT carries
//! the caller's identity, scopes, and expiry; there are no static tokens or unsigned
//! scope lists.
//!
//! Usage: `ati proxy --port 8080 [--ati-dir ~/.ati]`
7use axum::{
8    body::Body,
9    extract::State,
10    http::{Request as HttpRequest, StatusCode},
11    middleware::{self, Next},
12    response::{IntoResponse, Response},
13    routing::{get, post},
14    Json, Router,
15};
16use serde::{Deserialize, Serialize};
17use serde_json::Value;
18use std::collections::HashMap;
19use std::net::SocketAddr;
20use std::path::PathBuf;
21use std::sync::Arc;
22
23use crate::core::auth_generator::{AuthCache, GenContext};
24use crate::core::http;
25use crate::core::jwt::{self, JwtConfig, TokenClaims};
26use crate::core::keyring::Keyring;
27use crate::core::manifest::ManifestRegistry;
28use crate::core::mcp_client;
29use crate::core::response;
30use crate::core::scope::ScopeConfig;
31use crate::core::skill::{self, SkillRegistry};
32use crate::core::xai;
33
34/// Shared state for the proxy server.
35pub struct ProxyState {
36    pub registry: ManifestRegistry,
37    pub skill_registry: SkillRegistry,
38    pub keyring: Keyring,
39    pub verbose: bool,
40    /// JWT validation config (None = auth disabled / dev mode).
41    pub jwt_config: Option<JwtConfig>,
42    /// Pre-computed JWKS JSON for the /.well-known/jwks.json endpoint.
43    pub jwks_json: Option<Value>,
44    /// Shared cache for dynamically generated auth credentials.
45    pub auth_cache: AuthCache,
46}
47
48// --- Request/Response types ---
49
50#[derive(Debug, Deserialize)]
51pub struct CallRequest {
52    pub tool_name: String,
53    pub args: HashMap<String, Value>,
54    #[serde(default)]
55    pub raw_args: Option<Vec<String>>,
56}
57
58#[derive(Debug, Serialize)]
59pub struct CallResponse {
60    pub result: Value,
61    #[serde(skip_serializing_if = "Option::is_none")]
62    pub error: Option<String>,
63}
64
65#[derive(Debug, Deserialize)]
66pub struct HelpRequest {
67    pub query: String,
68    #[serde(default)]
69    pub tool: Option<String>,
70}
71
72#[derive(Debug, Serialize)]
73pub struct HelpResponse {
74    pub content: String,
75    #[serde(skip_serializing_if = "Option::is_none")]
76    pub error: Option<String>,
77}
78
79#[derive(Debug, Serialize)]
80pub struct HealthResponse {
81    pub status: String,
82    pub version: String,
83    pub tools: usize,
84    pub providers: usize,
85    pub skills: usize,
86    pub auth: String,
87}
88
89// --- Skill endpoint types ---
90
91#[derive(Debug, Deserialize)]
92pub struct SkillsQuery {
93    #[serde(default)]
94    pub category: Option<String>,
95    #[serde(default)]
96    pub provider: Option<String>,
97    #[serde(default)]
98    pub tool: Option<String>,
99    #[serde(default)]
100    pub search: Option<String>,
101}
102
103#[derive(Debug, Deserialize)]
104pub struct SkillDetailQuery {
105    #[serde(default)]
106    pub meta: Option<bool>,
107    #[serde(default)]
108    pub refs: Option<bool>,
109}
110
111#[derive(Debug, Deserialize)]
112pub struct SkillResolveRequest {
113    pub scopes: Vec<String>,
114}
115
116// --- Handlers ---
117
118async fn handle_call(
119    State(state): State<Arc<ProxyState>>,
120    req: HttpRequest<Body>,
121) -> impl IntoResponse {
122    // Extract JWT claims from request extensions (set by auth middleware)
123    let claims = req.extensions().get::<TokenClaims>().cloned();
124
125    // Parse request body
126    let body_bytes = match axum::body::to_bytes(req.into_body(), 10 * 1024 * 1024).await {
127        Ok(b) => b,
128        Err(e) => {
129            return (
130                StatusCode::BAD_REQUEST,
131                Json(CallResponse {
132                    result: Value::Null,
133                    error: Some(format!("Failed to read request body: {e}")),
134                }),
135            );
136        }
137    };
138
139    let call_req: CallRequest = match serde_json::from_slice(&body_bytes) {
140        Ok(r) => r,
141        Err(e) => {
142            return (
143                StatusCode::UNPROCESSABLE_ENTITY,
144                Json(CallResponse {
145                    result: Value::Null,
146                    error: Some(format!("Invalid request: {e}")),
147                }),
148            );
149        }
150    };
151
152    if state.verbose {
153        eprintln!(
154            "[proxy] /call tool={} args={:?}",
155            call_req.tool_name, call_req.args
156        );
157    }
158
159    // Look up tool in registry
160    let (provider, tool) = match state.registry.get_tool(&call_req.tool_name) {
161        Some(pt) => pt,
162        None => {
163            return (
164                StatusCode::NOT_FOUND,
165                Json(CallResponse {
166                    result: Value::Null,
167                    error: Some(format!("Unknown tool: '{}'", call_req.tool_name)),
168                }),
169            );
170        }
171    };
172
173    // Scope enforcement from JWT claims
174    if let Some(tool_scope) = &tool.scope {
175        let scopes = match &claims {
176            Some(c) => ScopeConfig::from_jwt(c),
177            None if state.jwt_config.is_none() => ScopeConfig::unrestricted(), // Dev mode
178            None => {
179                return (
180                    StatusCode::FORBIDDEN,
181                    Json(CallResponse {
182                        result: Value::Null,
183                        error: Some("Authentication required — no JWT provided".into()),
184                    }),
185                );
186            }
187        };
188
189        if let Err(e) = scopes.check_access(&call_req.tool_name, tool_scope) {
190            return (
191                StatusCode::FORBIDDEN,
192                Json(CallResponse {
193                    result: Value::Null,
194                    error: Some(format!("Access denied: {e}")),
195                }),
196            );
197        }
198    }
199
200    // Rate limit check
201    {
202        let scopes = match &claims {
203            Some(c) => ScopeConfig::from_jwt(c),
204            None => ScopeConfig::unrestricted(),
205        };
206        if let Some(ref rate_config) = scopes.rate_config {
207            if let Err(e) = crate::core::rate::check_and_record(&call_req.tool_name, rate_config) {
208                return (
209                    StatusCode::TOO_MANY_REQUESTS,
210                    Json(CallResponse {
211                        result: Value::Null,
212                        error: Some(format!("{e}")),
213                    }),
214                );
215            }
216        }
217    }
218
219    // Build auth generator context from JWT claims
220    let gen_ctx = GenContext {
221        jwt_sub: claims
222            .as_ref()
223            .map(|c| c.sub.clone())
224            .unwrap_or_else(|| "dev".into()),
225        jwt_scope: claims
226            .as_ref()
227            .map(|c| c.scope.clone())
228            .unwrap_or_else(|| "*".into()),
229        tool_name: call_req.tool_name.clone(),
230        timestamp: crate::core::jwt::now_secs(),
231    };
232
233    // Execute tool call — dispatch based on handler type, with timing for audit
234    let agent_sub = claims.as_ref().map(|c| c.sub.clone()).unwrap_or_default();
235    let start = std::time::Instant::now();
236
237    let response = match provider.handler.as_str() {
238        "mcp" => {
239            match mcp_client::execute_with_gen(
240                provider,
241                &call_req.tool_name,
242                &call_req.args,
243                &state.keyring,
244                Some(&gen_ctx),
245                Some(&state.auth_cache),
246            )
247            .await
248            {
249                Ok(result) => (
250                    StatusCode::OK,
251                    Json(CallResponse {
252                        result,
253                        error: None,
254                    }),
255                ),
256                Err(e) => (
257                    StatusCode::BAD_GATEWAY,
258                    Json(CallResponse {
259                        result: Value::Null,
260                        error: Some(format!("MCP error: {e}")),
261                    }),
262                ),
263            }
264        }
265        "cli" => {
266            let raw = call_req.raw_args.as_deref().unwrap_or(&[]);
267            match crate::core::cli_executor::execute_with_gen(
268                provider,
269                raw,
270                &state.keyring,
271                Some(&gen_ctx),
272                Some(&state.auth_cache),
273            )
274            .await
275            {
276                Ok(result) => (
277                    StatusCode::OK,
278                    Json(CallResponse {
279                        result,
280                        error: None,
281                    }),
282                ),
283                Err(e) => (
284                    StatusCode::BAD_GATEWAY,
285                    Json(CallResponse {
286                        result: Value::Null,
287                        error: Some(format!("CLI error: {e}")),
288                    }),
289                ),
290            }
291        }
292        _ => {
293            let raw_response = match match provider.handler.as_str() {
294                "xai" => {
295                    xai::execute_xai_tool(provider, tool, &call_req.args, &state.keyring).await
296                }
297                _ => {
298                    http::execute_tool_with_gen(
299                        provider,
300                        tool,
301                        &call_req.args,
302                        &state.keyring,
303                        Some(&gen_ctx),
304                        Some(&state.auth_cache),
305                    )
306                    .await
307                }
308            } {
309                Ok(resp) => resp,
310                Err(e) => {
311                    let duration = start.elapsed();
312                    write_proxy_audit(&call_req, &agent_sub, duration, Some(&e.to_string()));
313                    return (
314                        StatusCode::BAD_GATEWAY,
315                        Json(CallResponse {
316                            result: Value::Null,
317                            error: Some(format!("Upstream API error: {e}")),
318                        }),
319                    );
320                }
321            };
322
323            let processed = match response::process_response(&raw_response, tool.response.as_ref())
324            {
325                Ok(p) => p,
326                Err(e) => {
327                    let duration = start.elapsed();
328                    write_proxy_audit(&call_req, &agent_sub, duration, Some(&e.to_string()));
329                    return (
330                        StatusCode::INTERNAL_SERVER_ERROR,
331                        Json(CallResponse {
332                            result: raw_response,
333                            error: Some(format!("Response processing error: {e}")),
334                        }),
335                    );
336                }
337            };
338
339            (
340                StatusCode::OK,
341                Json(CallResponse {
342                    result: processed,
343                    error: None,
344                }),
345            )
346        }
347    };
348
349    let duration = start.elapsed();
350    let error_msg = response.1.error.as_deref();
351    write_proxy_audit(&call_req, &agent_sub, duration, error_msg);
352
353    response
354}
355
356async fn handle_help(
357    State(state): State<Arc<ProxyState>>,
358    Json(req): Json<HelpRequest>,
359) -> impl IntoResponse {
360    if state.verbose {
361        eprintln!("[proxy] /help query={} tool={:?}", req.query, req.tool);
362    }
363
364    let (llm_provider, llm_tool) = match state.registry.get_tool("_chat_completion") {
365        Some(pt) => pt,
366        None => {
367            return (
368                StatusCode::SERVICE_UNAVAILABLE,
369                Json(HelpResponse {
370                    content: String::new(),
371                    error: Some("No _llm.toml manifest found. Proxy help requires a configured LLM provider.".into()),
372                }),
373            );
374        }
375    };
376
377    let api_key = match llm_provider
378        .auth_key_name
379        .as_deref()
380        .and_then(|k| state.keyring.get(k))
381    {
382        Some(key) => key.to_string(),
383        None => {
384            return (
385                StatusCode::SERVICE_UNAVAILABLE,
386                Json(HelpResponse {
387                    content: String::new(),
388                    error: Some("LLM API key not found in keyring".into()),
389                }),
390            );
391        }
392    };
393
394    let scopes = ScopeConfig::unrestricted();
395    let resolved_skills = skill::resolve_skills(&state.skill_registry, &state.registry, &scopes);
396    let skills_section = if resolved_skills.is_empty() {
397        String::new()
398    } else {
399        format!(
400            "## Available Skills (methodology guides)\n{}",
401            skill::build_skill_context(&resolved_skills)
402        )
403    };
404
405    // Build system prompt — scoped or unscoped
406    let system_prompt = if let Some(ref tool_name) = req.tool {
407        // Scoped mode: narrow tools to the specified tool or provider
408        match build_scoped_prompt(tool_name, &state.registry, &skills_section) {
409            Some(prompt) => prompt,
410            None => {
411                // Fall back to unscoped if tool/provider not found
412                if state.verbose {
413                    eprintln!(
414                        "[proxy] /help scope '{}' not found, falling back to unscoped",
415                        tool_name
416                    );
417                }
418                let all_tools = state.registry.list_public_tools();
419                let tools_context = build_tool_context(&all_tools);
420                HELP_SYSTEM_PROMPT
421                    .replace("{tools}", &tools_context)
422                    .replace("{skills_section}", &skills_section)
423            }
424        }
425    } else {
426        let all_tools = state.registry.list_public_tools();
427        let tools_context = build_tool_context(&all_tools);
428        HELP_SYSTEM_PROMPT
429            .replace("{tools}", &tools_context)
430            .replace("{skills_section}", &skills_section)
431    };
432
433    let request_body = serde_json::json!({
434        "model": "zai-glm-4.7",
435        "messages": [
436            {"role": "system", "content": system_prompt},
437            {"role": "user", "content": req.query}
438        ],
439        "max_completion_tokens": 1536,
440        "temperature": 0.3
441    });
442
443    let client = reqwest::Client::new();
444    let url = format!(
445        "{}{}",
446        llm_provider.base_url.trim_end_matches('/'),
447        llm_tool.endpoint
448    );
449
450    let response = match client
451        .post(&url)
452        .bearer_auth(&api_key)
453        .json(&request_body)
454        .send()
455        .await
456    {
457        Ok(r) => r,
458        Err(e) => {
459            return (
460                StatusCode::BAD_GATEWAY,
461                Json(HelpResponse {
462                    content: String::new(),
463                    error: Some(format!("LLM request failed: {e}")),
464                }),
465            );
466        }
467    };
468
469    if !response.status().is_success() {
470        let status = response.status();
471        let body = response.text().await.unwrap_or_default();
472        return (
473            StatusCode::BAD_GATEWAY,
474            Json(HelpResponse {
475                content: String::new(),
476                error: Some(format!("LLM API error ({status}): {body}")),
477            }),
478        );
479    }
480
481    let body: Value = match response.json().await {
482        Ok(b) => b,
483        Err(e) => {
484            return (
485                StatusCode::INTERNAL_SERVER_ERROR,
486                Json(HelpResponse {
487                    content: String::new(),
488                    error: Some(format!("Failed to parse LLM response: {e}")),
489                }),
490            );
491        }
492    };
493
494    let content = body
495        .pointer("/choices/0/message/content")
496        .and_then(|c| c.as_str())
497        .unwrap_or("No response from LLM")
498        .to_string();
499
500    (
501        StatusCode::OK,
502        Json(HelpResponse {
503            content,
504            error: None,
505        }),
506    )
507}
508
509async fn handle_health(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
510    let auth = if state.jwt_config.is_some() {
511        "jwt"
512    } else {
513        "disabled"
514    };
515
516    Json(HealthResponse {
517        status: "ok".into(),
518        version: env!("CARGO_PKG_VERSION").into(),
519        tools: state.registry.list_public_tools().len(),
520        providers: state.registry.list_providers().len(),
521        skills: state.skill_registry.skill_count(),
522        auth: auth.into(),
523    })
524}
525
526/// GET /.well-known/jwks.json — serves the public key for JWT validation.
527async fn handle_jwks(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
528    match &state.jwks_json {
529        Some(jwks) => (StatusCode::OK, Json(jwks.clone())),
530        None => (
531            StatusCode::NOT_FOUND,
532            Json(serde_json::json!({"error": "JWKS not configured"})),
533        ),
534    }
535}
536
537// ---------------------------------------------------------------------------
538// POST /mcp — MCP JSON-RPC proxy endpoint
539// ---------------------------------------------------------------------------
540
541async fn handle_mcp(
542    State(state): State<Arc<ProxyState>>,
543    Json(msg): Json<Value>,
544) -> impl IntoResponse {
545    let method = msg.get("method").and_then(|m| m.as_str()).unwrap_or("");
546    let id = msg.get("id").cloned();
547
548    if state.verbose {
549        eprintln!("[proxy] /mcp method={method}");
550    }
551
552    match method {
553        "initialize" => {
554            let result = serde_json::json!({
555                "protocolVersion": "2025-03-26",
556                "capabilities": {
557                    "tools": { "listChanged": false }
558                },
559                "serverInfo": {
560                    "name": "ati-proxy",
561                    "version": env!("CARGO_PKG_VERSION")
562                }
563            });
564            jsonrpc_success(id, result)
565        }
566
567        "notifications/initialized" => (StatusCode::ACCEPTED, Json(Value::Null)),
568
569        "tools/list" => {
570            let all_tools = state.registry.list_public_tools();
571            let mcp_tools: Vec<Value> = all_tools
572                .iter()
573                .map(|(_provider, tool)| {
574                    serde_json::json!({
575                        "name": tool.name,
576                        "description": tool.description,
577                        "inputSchema": tool.input_schema.clone().unwrap_or(serde_json::json!({
578                            "type": "object",
579                            "properties": {}
580                        }))
581                    })
582                })
583                .collect();
584
585            let result = serde_json::json!({
586                "tools": mcp_tools,
587            });
588            jsonrpc_success(id, result)
589        }
590
591        "tools/call" => {
592            let params = msg.get("params").cloned().unwrap_or(Value::Null);
593            let tool_name = params.get("name").and_then(|n| n.as_str()).unwrap_or("");
594            let arguments: HashMap<String, Value> = params
595                .get("arguments")
596                .and_then(|a| serde_json::from_value(a.clone()).ok())
597                .unwrap_or_default();
598
599            if tool_name.is_empty() {
600                return jsonrpc_error(id, -32602, "Missing tool name in params.name");
601            }
602
603            let (provider, _tool) = match state.registry.get_tool(tool_name) {
604                Some(pt) => pt,
605                None => {
606                    return jsonrpc_error(id, -32602, &format!("Unknown tool: '{tool_name}'"));
607                }
608            };
609
610            if state.verbose {
611                eprintln!(
612                    "[proxy] /mcp tools/call name={tool_name} provider={}",
613                    provider.name
614                );
615            }
616
617            let mcp_gen_ctx = GenContext {
618                jwt_sub: "dev".into(),
619                jwt_scope: "*".into(),
620                tool_name: tool_name.to_string(),
621                timestamp: crate::core::jwt::now_secs(),
622            };
623
624            let result = if provider.is_mcp() {
625                mcp_client::execute_with_gen(
626                    provider,
627                    tool_name,
628                    &arguments,
629                    &state.keyring,
630                    Some(&mcp_gen_ctx),
631                    Some(&state.auth_cache),
632                )
633                .await
634            } else if provider.is_cli() {
635                // Convert arguments map to CLI-style args for MCP passthrough
636                let raw: Vec<String> = arguments
637                    .iter()
638                    .flat_map(|(k, v)| {
639                        let val = match v {
640                            Value::String(s) => s.clone(),
641                            other => other.to_string(),
642                        };
643                        vec![format!("--{k}"), val]
644                    })
645                    .collect();
646                crate::core::cli_executor::execute_with_gen(
647                    provider,
648                    &raw,
649                    &state.keyring,
650                    Some(&mcp_gen_ctx),
651                    Some(&state.auth_cache),
652                )
653                .await
654                .map_err(|e| mcp_client::McpError::Transport(e.to_string()))
655            } else {
656                match match provider.handler.as_str() {
657                    "xai" => {
658                        xai::execute_xai_tool(provider, _tool, &arguments, &state.keyring).await
659                    }
660                    _ => {
661                        http::execute_tool_with_gen(
662                            provider,
663                            _tool,
664                            &arguments,
665                            &state.keyring,
666                            Some(&mcp_gen_ctx),
667                            Some(&state.auth_cache),
668                        )
669                        .await
670                    }
671                } {
672                    Ok(val) => Ok(val),
673                    Err(e) => Err(mcp_client::McpError::Transport(e.to_string())),
674                }
675            };
676
677            match result {
678                Ok(value) => {
679                    let text = match &value {
680                        Value::String(s) => s.clone(),
681                        other => serde_json::to_string_pretty(other).unwrap_or_default(),
682                    };
683                    let mcp_result = serde_json::json!({
684                        "content": [{"type": "text", "text": text}],
685                        "isError": false,
686                    });
687                    jsonrpc_success(id, mcp_result)
688                }
689                Err(e) => {
690                    let mcp_result = serde_json::json!({
691                        "content": [{"type": "text", "text": format!("Error: {e}")}],
692                        "isError": true,
693                    });
694                    jsonrpc_success(id, mcp_result)
695                }
696            }
697        }
698
699        _ => jsonrpc_error(id, -32601, &format!("Method not found: '{method}'")),
700    }
701}
702
703fn jsonrpc_success(id: Option<Value>, result: Value) -> (StatusCode, Json<Value>) {
704    (
705        StatusCode::OK,
706        Json(serde_json::json!({
707            "jsonrpc": "2.0",
708            "id": id,
709            "result": result,
710        })),
711    )
712}
713
714fn jsonrpc_error(id: Option<Value>, code: i64, message: &str) -> (StatusCode, Json<Value>) {
715    (
716        StatusCode::OK,
717        Json(serde_json::json!({
718            "jsonrpc": "2.0",
719            "id": id,
720            "error": {
721                "code": code,
722                "message": message,
723            }
724        })),
725    )
726}
727
728// ---------------------------------------------------------------------------
729// Skill endpoints
730// ---------------------------------------------------------------------------
731
732async fn handle_skills_list(
733    State(state): State<Arc<ProxyState>>,
734    axum::extract::Query(query): axum::extract::Query<SkillsQuery>,
735) -> impl IntoResponse {
736    if state.verbose {
737        eprintln!(
738            "[proxy] GET /skills category={:?} provider={:?} tool={:?} search={:?}",
739            query.category, query.provider, query.tool, query.search
740        );
741    }
742
743    let skills: Vec<&skill::SkillMeta> = if let Some(search_query) = &query.search {
744        state.skill_registry.search(search_query)
745    } else if let Some(cat) = &query.category {
746        state.skill_registry.skills_for_category(cat)
747    } else if let Some(prov) = &query.provider {
748        state.skill_registry.skills_for_provider(prov)
749    } else if let Some(t) = &query.tool {
750        state.skill_registry.skills_for_tool(t)
751    } else {
752        state.skill_registry.list_skills().iter().collect()
753    };
754
755    let json: Vec<Value> = skills
756        .iter()
757        .map(|s| {
758            serde_json::json!({
759                "name": s.name,
760                "version": s.version,
761                "description": s.description,
762                "tools": s.tools,
763                "providers": s.providers,
764                "categories": s.categories,
765                "hint": s.hint,
766            })
767        })
768        .collect();
769
770    (StatusCode::OK, Json(Value::Array(json)))
771}
772
773async fn handle_skill_detail(
774    State(state): State<Arc<ProxyState>>,
775    axum::extract::Path(name): axum::extract::Path<String>,
776    axum::extract::Query(query): axum::extract::Query<SkillDetailQuery>,
777) -> impl IntoResponse {
778    if state.verbose {
779        eprintln!(
780            "[proxy] GET /skills/{name} meta={:?} refs={:?}",
781            query.meta, query.refs
782        );
783    }
784
785    let skill_meta = match state.skill_registry.get_skill(&name) {
786        Some(s) => s,
787        None => {
788            return (
789                StatusCode::NOT_FOUND,
790                Json(serde_json::json!({"error": format!("Skill '{name}' not found")})),
791            );
792        }
793    };
794
795    if query.meta.unwrap_or(false) {
796        return (
797            StatusCode::OK,
798            Json(serde_json::json!({
799                "name": skill_meta.name,
800                "version": skill_meta.version,
801                "description": skill_meta.description,
802                "author": skill_meta.author,
803                "tools": skill_meta.tools,
804                "providers": skill_meta.providers,
805                "categories": skill_meta.categories,
806                "keywords": skill_meta.keywords,
807                "hint": skill_meta.hint,
808                "depends_on": skill_meta.depends_on,
809                "suggests": skill_meta.suggests,
810                "license": skill_meta.license,
811                "compatibility": skill_meta.compatibility,
812                "allowed_tools": skill_meta.allowed_tools,
813                "format": skill_meta.format,
814            })),
815        );
816    }
817
818    let content = match state.skill_registry.read_content(&name) {
819        Ok(c) => c,
820        Err(e) => {
821            return (
822                StatusCode::INTERNAL_SERVER_ERROR,
823                Json(serde_json::json!({"error": format!("Failed to read skill: {e}")})),
824            );
825        }
826    };
827
828    let mut response = serde_json::json!({
829        "name": skill_meta.name,
830        "version": skill_meta.version,
831        "description": skill_meta.description,
832        "content": content,
833    });
834
835    if query.refs.unwrap_or(false) {
836        if let Ok(refs) = state.skill_registry.list_references(&name) {
837            response["references"] = serde_json::json!(refs);
838        }
839    }
840
841    (StatusCode::OK, Json(response))
842}
843
844async fn handle_skills_resolve(
845    State(state): State<Arc<ProxyState>>,
846    Json(req): Json<SkillResolveRequest>,
847) -> impl IntoResponse {
848    if state.verbose {
849        eprintln!("[proxy] POST /skills/resolve scopes={:?}", req.scopes);
850    }
851
852    let scopes = ScopeConfig {
853        scopes: req.scopes,
854        sub: String::new(),
855        expires_at: 0,
856        rate_config: None,
857    };
858
859    let resolved = skill::resolve_skills(&state.skill_registry, &state.registry, &scopes);
860
861    let json: Vec<Value> = resolved
862        .iter()
863        .map(|s| {
864            serde_json::json!({
865                "name": s.name,
866                "version": s.version,
867                "description": s.description,
868                "tools": s.tools,
869                "providers": s.providers,
870                "categories": s.categories,
871            })
872        })
873        .collect();
874
875    (StatusCode::OK, Json(Value::Array(json)))
876}
877
878// --- Auth middleware ---
879
880/// JWT authentication middleware.
881///
882/// - /health and /.well-known/jwks.json → skip auth
883/// - JWT configured → validate Bearer token, attach claims to request extensions
884/// - No JWT configured → allow all (dev mode)
885async fn auth_middleware(
886    State(state): State<Arc<ProxyState>>,
887    mut req: HttpRequest<Body>,
888    next: Next,
889) -> Result<Response, StatusCode> {
890    let path = req.uri().path();
891
892    // Skip auth for public endpoints
893    if path == "/health" || path == "/.well-known/jwks.json" {
894        return Ok(next.run(req).await);
895    }
896
897    // If no JWT configured, allow all (dev mode)
898    let jwt_config = match &state.jwt_config {
899        Some(c) => c,
900        None => return Ok(next.run(req).await),
901    };
902
903    // Extract Authorization: Bearer <token>
904    let auth_header = req
905        .headers()
906        .get("authorization")
907        .and_then(|v| v.to_str().ok());
908
909    let token = match auth_header {
910        Some(header) if header.starts_with("Bearer ") => &header[7..],
911        _ => return Err(StatusCode::UNAUTHORIZED),
912    };
913
914    // Validate JWT
915    match jwt::validate(token, jwt_config) {
916        Ok(claims) => {
917            if state.verbose {
918                eprintln!(
919                    "[proxy] JWT validated: sub={} scopes={}",
920                    claims.sub, claims.scope
921                );
922            }
923            req.extensions_mut().insert(claims);
924            Ok(next.run(req).await)
925        }
926        Err(e) => {
927            if state.verbose {
928                eprintln!("[proxy] JWT validation failed: {e}");
929            }
930            Err(StatusCode::UNAUTHORIZED)
931        }
932    }
933}
934
935// --- Router builder ---
936
937/// Build the axum Router from a pre-constructed ProxyState.
938pub fn build_router(state: Arc<ProxyState>) -> Router {
939    Router::new()
940        .route("/call", post(handle_call))
941        .route("/help", post(handle_help))
942        .route("/mcp", post(handle_mcp))
943        .route("/skills", get(handle_skills_list))
944        .route("/skills/resolve", post(handle_skills_resolve))
945        .route("/skills/{name}", get(handle_skill_detail))
946        .route("/health", get(handle_health))
947        .route("/.well-known/jwks.json", get(handle_jwks))
948        .layer(middleware::from_fn_with_state(
949            state.clone(),
950            auth_middleware,
951        ))
952        .with_state(state)
953}
954
955// --- Server startup ---
956
957/// Start the proxy server.
958pub async fn run(
959    port: u16,
960    bind_addr: Option<String>,
961    ati_dir: PathBuf,
962    verbose: bool,
963    env_keys: bool,
964) -> Result<(), Box<dyn std::error::Error>> {
965    // Load manifests
966    let manifests_dir = ati_dir.join("manifests");
967    let registry = ManifestRegistry::load(&manifests_dir)?;
968
969    let tool_count = registry.list_public_tools().len();
970    let provider_count = registry.list_providers().len();
971
972    // Load keyring
973    let keyring_source;
974    let keyring = if env_keys {
975        // --env-keys: scan ATI_KEY_* environment variables
976        let kr = Keyring::from_env();
977        let key_names = kr.key_names();
978        eprintln!(
979            "  Loaded {} API keys from ATI_KEY_* env vars:",
980            key_names.len()
981        );
982        for name in &key_names {
983            eprintln!("    - {name}");
984        }
985        keyring_source = "env-vars (ATI_KEY_*)";
986        kr
987    } else {
988        // Cascade: keyring.enc (sealed) → keyring.enc (persistent) → credentials → empty
989        let keyring_path = ati_dir.join("keyring.enc");
990        if keyring_path.exists() {
991            if let Ok(kr) = Keyring::load(&keyring_path) {
992                keyring_source = "keyring.enc (sealed key)";
993                kr
994            } else if let Ok(kr) = Keyring::load_local(&keyring_path, &ati_dir) {
995                keyring_source = "keyring.enc (persistent key)";
996                kr
997            } else {
998                eprintln!("Warning: keyring.enc exists but could not be decrypted");
999                keyring_source = "empty (decryption failed)";
1000                Keyring::empty()
1001            }
1002        } else {
1003            let creds_path = ati_dir.join("credentials");
1004            if creds_path.exists() {
1005                match Keyring::load_credentials(&creds_path) {
1006                    Ok(kr) => {
1007                        keyring_source = "credentials (plaintext)";
1008                        kr
1009                    }
1010                    Err(e) => {
1011                        eprintln!("Warning: Failed to load credentials: {e}");
1012                        keyring_source = "empty (credentials error)";
1013                        Keyring::empty()
1014                    }
1015                }
1016            } else {
1017                eprintln!(
1018                    "Warning: No keyring.enc or credentials found — running without API keys"
1019                );
1020                eprintln!("Tools requiring authentication will fail.");
1021                keyring_source = "empty (no auth)";
1022                Keyring::empty()
1023            }
1024        }
1025    };
1026
1027    // Log MCP and OpenAPI providers
1028    let mcp_providers: Vec<(String, String)> = registry
1029        .list_mcp_providers()
1030        .iter()
1031        .map(|p| (p.name.clone(), p.mcp_transport_type().to_string()))
1032        .collect();
1033    let mcp_count = mcp_providers.len();
1034    let openapi_providers: Vec<String> = registry
1035        .list_openapi_providers()
1036        .iter()
1037        .map(|p| p.name.clone())
1038        .collect();
1039    let openapi_count = openapi_providers.len();
1040
1041    // Load skill registry
1042    let skills_dir = ati_dir.join("skills");
1043    let skill_registry = SkillRegistry::load(&skills_dir).unwrap_or_else(|e| {
1044        if verbose {
1045            eprintln!("Warning: failed to load skills: {e}");
1046        }
1047        SkillRegistry::load(std::path::Path::new("/nonexistent-fallback")).unwrap()
1048    });
1049    let skill_count = skill_registry.skill_count();
1050
1051    // Load JWT config from environment
1052    let jwt_config = match jwt::config_from_env() {
1053        Ok(config) => config,
1054        Err(e) => {
1055            eprintln!("Warning: JWT config error: {e}");
1056            None
1057        }
1058    };
1059
1060    let auth_status = if jwt_config.is_some() {
1061        "JWT enabled"
1062    } else {
1063        "DISABLED (no JWT keys configured)"
1064    };
1065
1066    // Build JWKS for the endpoint
1067    let jwks_json = jwt_config.as_ref().and_then(|config| {
1068        config
1069            .public_key_pem
1070            .as_ref()
1071            .and_then(|pem| jwt::public_key_to_jwks(pem, config.algorithm, "ati-proxy-1").ok())
1072    });
1073
1074    let state = Arc::new(ProxyState {
1075        registry,
1076        skill_registry,
1077        keyring,
1078        verbose,
1079        jwt_config,
1080        jwks_json,
1081        auth_cache: AuthCache::new(),
1082    });
1083
1084    let app = build_router(state);
1085
1086    let addr: SocketAddr = if let Some(ref bind) = bind_addr {
1087        format!("{bind}:{port}").parse()?
1088    } else {
1089        SocketAddr::from(([127, 0, 0, 1], port))
1090    };
1091
1092    eprintln!("ATI proxy server v{}", env!("CARGO_PKG_VERSION"));
1093    eprintln!("  Listening on http://{addr}");
1094    eprintln!("  Auth: {auth_status}");
1095    eprintln!("  ATI dir: {}", ati_dir.display());
1096    eprintln!("  Tools: {tool_count}, Providers: {provider_count} ({mcp_count} MCP, {openapi_count} OpenAPI)");
1097    eprintln!("  Skills: {skill_count}");
1098    eprintln!("  Keyring: {keyring_source}");
1099    eprintln!("  Endpoints: /call, /help, /mcp, /skills, /skills/:name, /skills/resolve, /health, /.well-known/jwks.json");
1100    for (name, transport) in &mcp_providers {
1101        eprintln!("  MCP: {name} ({transport})");
1102    }
1103    for name in &openapi_providers {
1104        eprintln!("  OpenAPI: {name}");
1105    }
1106
1107    let listener = tokio::net::TcpListener::bind(addr).await?;
1108    axum::serve(listener, app).await?;
1109
1110    Ok(())
1111}
1112
1113/// Write an audit entry from the proxy server. Failures are silently ignored.
1114fn write_proxy_audit(
1115    call_req: &CallRequest,
1116    agent_sub: &str,
1117    duration: std::time::Duration,
1118    error: Option<&str>,
1119) {
1120    let entry = crate::core::audit::AuditEntry {
1121        ts: chrono::Utc::now().to_rfc3339(),
1122        tool: call_req.tool_name.clone(),
1123        args: crate::core::audit::sanitize_args(&serde_json::json!(call_req.args)),
1124        status: if error.is_some() {
1125            crate::core::audit::AuditStatus::Error
1126        } else {
1127            crate::core::audit::AuditStatus::Ok
1128        },
1129        duration_ms: duration.as_millis() as u64,
1130        agent_sub: agent_sub.to_string(),
1131        error: error.map(|s| s.to_string()),
1132        exit_code: None,
1133    };
1134    let _ = crate::core::audit::append(&entry);
1135}
1136
1137// --- Helpers ---
1138
/// System prompt template for the general help assistant. This is presumably
/// used when no tool or provider scope is given; `build_scoped_prompt` builds
/// the scoped variants.
///
/// The text contains two literal placeholders, `{tools}` and `{skills_section}`.
/// This is a raw string constant rather than a `format!` template, so the braces
/// are plain text and must be substituted by the caller, presumably via string
/// replacement. NOTE(review): confirm against the help handler.
const HELP_SYSTEM_PROMPT: &str = r#"You are a helpful assistant for an AI agent that uses external tools via the `ati` CLI.

## Available Tools
{tools}

{skills_section}

Answer the agent's question naturally, like a knowledgeable colleague would. Keep it short but useful:

- Explain which tools to use and why, with `ati run` commands showing realistic parameter values
- If multiple steps are needed, walk through them briefly in order
- Mention important gotchas or parameter choices that matter
- If skills are relevant, suggest `ati skill show <name>` for the full methodology

Keep your answer concise — a few short paragraphs with embedded code blocks. Only recommend tools from the list above."#;
1154
1155fn build_tool_context(
1156    tools: &[(
1157        &crate::core::manifest::Provider,
1158        &crate::core::manifest::Tool,
1159    )],
1160) -> String {
1161    let mut summaries = Vec::new();
1162    for (provider, tool) in tools {
1163        let mut summary = if let Some(cat) = &provider.category {
1164            format!(
1165                "- **{}** (provider: {}, category: {}): {}",
1166                tool.name, provider.name, cat, tool.description
1167            )
1168        } else {
1169            format!(
1170                "- **{}** (provider: {}): {}",
1171                tool.name, provider.name, tool.description
1172            )
1173        };
1174        if !tool.tags.is_empty() {
1175            summary.push_str(&format!("\n  Tags: {}", tool.tags.join(", ")));
1176        }
1177        // CLI tools: show passthrough usage
1178        if provider.is_cli() && tool.input_schema.is_none() {
1179            let cmd = provider.cli_command.as_deref().unwrap_or("?");
1180            summary.push_str(&format!(
1181                "\n  Usage: `ati run {} -- <args>`  (passthrough to `{}`)",
1182                tool.name, cmd
1183            ));
1184        } else if let Some(schema) = &tool.input_schema {
1185            if let Some(props) = schema.get("properties") {
1186                if let Some(obj) = props.as_object() {
1187                    let params: Vec<String> = obj
1188                        .iter()
1189                        .filter(|(_, v)| {
1190                            v.get("x-ati-param-location").is_none()
1191                                || v.get("description").is_some()
1192                        })
1193                        .map(|(k, v)| {
1194                            let type_str =
1195                                v.get("type").and_then(|t| t.as_str()).unwrap_or("string");
1196                            let desc = v.get("description").and_then(|d| d.as_str()).unwrap_or("");
1197                            format!("    --{k} ({type_str}): {desc}")
1198                        })
1199                        .collect();
1200                    if !params.is_empty() {
1201                        summary.push_str("\n  Parameters:\n");
1202                        summary.push_str(&params.join("\n"));
1203                    }
1204                }
1205            }
1206        }
1207        summaries.push(summary);
1208    }
1209    summaries.join("\n\n")
1210}
1211
1212/// Build a scoped system prompt for a specific tool or provider.
1213///
1214/// Returns None if the scope_name doesn't match any tool or provider.
1215fn build_scoped_prompt(
1216    scope_name: &str,
1217    registry: &ManifestRegistry,
1218    skills_section: &str,
1219) -> Option<String> {
1220    // Check if scope_name is a tool
1221    if let Some((provider, tool)) = registry.get_tool(scope_name) {
1222        let mut details = format!(
1223            "**Name**: `{}`\n**Provider**: {} (handler: {})\n**Description**: {}\n",
1224            tool.name, provider.name, provider.handler, tool.description
1225        );
1226        if let Some(cat) = &provider.category {
1227            details.push_str(&format!("**Category**: {}\n", cat));
1228        }
1229        if provider.is_cli() {
1230            let cmd = provider.cli_command.as_deref().unwrap_or("?");
1231            details.push_str(&format!(
1232                "\n**Usage**: `ati run {} -- <args>`  (passthrough to `{}`)\n",
1233                tool.name, cmd
1234            ));
1235        } else if let Some(schema) = &tool.input_schema {
1236            if let Some(props) = schema.get("properties").and_then(|p| p.as_object()) {
1237                let required: Vec<String> = schema
1238                    .get("required")
1239                    .and_then(|r| r.as_array())
1240                    .map(|arr| {
1241                        arr.iter()
1242                            .filter_map(|v| v.as_str().map(|s| s.to_string()))
1243                            .collect()
1244                    })
1245                    .unwrap_or_default();
1246                details.push_str("\n**Parameters**:\n");
1247                for (key, val) in props {
1248                    let type_str = val.get("type").and_then(|t| t.as_str()).unwrap_or("string");
1249                    let desc = val
1250                        .get("description")
1251                        .and_then(|d| d.as_str())
1252                        .unwrap_or("");
1253                    let req = if required.contains(key) {
1254                        " **(required)**"
1255                    } else {
1256                        ""
1257                    };
1258                    details.push_str(&format!("- `--{key}` ({type_str}{req}): {desc}\n"));
1259                }
1260            }
1261        }
1262
1263        let prompt = format!(
1264            "You are an expert assistant for the `{}` tool, accessed via the `ati` CLI.\n\n\
1265            ## Tool Details\n{}\n\n{}\n\n\
1266            Answer the agent's question about this specific tool. Provide exact commands, explain flags and options, and give practical examples. Be concise and actionable.",
1267            tool.name, details, skills_section
1268        );
1269        return Some(prompt);
1270    }
1271
1272    // Check if scope_name is a provider
1273    if registry.has_provider(scope_name) {
1274        let tools = registry.tools_by_provider(scope_name);
1275        if tools.is_empty() {
1276            return None;
1277        }
1278        let tools_context = build_tool_context(&tools);
1279        let prompt = format!(
1280            "You are an expert assistant for the `{}` provider's tools, accessed via the `ati` CLI.\n\n\
1281            ## Tools in provider `{}`\n{}\n\n{}\n\n\
1282            Answer the agent's question about these tools. Provide exact `ati run` commands, explain parameters, and give practical examples. Be concise and actionable.",
1283            scope_name, scope_name, tools_context, skills_section
1284        );
1285        return Some(prompt);
1286    }
1287
1288    None
1289}