Skip to main content

ati/proxy/
server.rs

//! ATI proxy server — holds API keys and executes tool calls on behalf of sandbox agents.
//!
//! Authentication: ES256-signed JWT (or HS256 fallback). The JWT carries identity,
//! scopes, and expiry; static tokens and unsigned scope lists are no longer supported.
//!
//! Usage: `ati proxy --port 8080 [--ati-dir ~/.ati]`
7use axum::{
8    body::Body,
9    extract::State,
10    http::{Request as HttpRequest, StatusCode},
11    middleware::{self, Next},
12    response::{IntoResponse, Response},
13    routing::{get, post},
14    Json, Router,
15};
16use serde::{Deserialize, Serialize};
17use serde_json::Value;
18use std::collections::HashMap;
19use std::net::SocketAddr;
20use std::path::PathBuf;
21use std::sync::Arc;
22
23use crate::core::auth_generator::{AuthCache, GenContext};
24use crate::core::http;
25use crate::core::jwt::{self, JwtConfig, TokenClaims};
26use crate::core::keyring::Keyring;
27use crate::core::manifest::ManifestRegistry;
28use crate::core::mcp_client;
29use crate::core::response;
30use crate::core::scope::ScopeConfig;
31use crate::core::skill::{self, SkillRegistry};
32use crate::core::xai;
33
34/// Shared state for the proxy server.
35pub struct ProxyState {
36    pub registry: ManifestRegistry,
37    pub skill_registry: SkillRegistry,
38    pub keyring: Keyring,
39    /// JWT validation config (None = auth disabled / dev mode).
40    pub jwt_config: Option<JwtConfig>,
41    /// Pre-computed JWKS JSON for the /.well-known/jwks.json endpoint.
42    pub jwks_json: Option<Value>,
43    /// Shared cache for dynamically generated auth credentials.
44    pub auth_cache: AuthCache,
45}
46
47// --- Request/Response types ---
48
49#[derive(Debug, Deserialize)]
50pub struct CallRequest {
51    pub tool_name: String,
52    /// Tool arguments — accepts a JSON object (key-value pairs) for HTTP/MCP/OpenAPI tools,
53    /// or a JSON array of strings / a single string for CLI tools.
54    /// The proxy auto-detects the handler type and routes accordingly.
55    #[serde(default = "default_args")]
56    pub args: Value,
57    /// Deprecated: use `args` with an array value instead.
58    /// Kept for backward compatibility — if present, takes precedence for CLI tools.
59    #[serde(default)]
60    pub raw_args: Option<Vec<String>>,
61}
62
63fn default_args() -> Value {
64    Value::Object(serde_json::Map::new())
65}
66
67impl CallRequest {
68    /// Extract args as a HashMap for HTTP/MCP/OpenAPI tools.
69    /// If `args` is a JSON object, returns its entries.
70    /// If `args` is something else (array, string), returns an empty map.
71    fn args_as_map(&self) -> HashMap<String, Value> {
72        match &self.args {
73            Value::Object(map) => map.iter().map(|(k, v)| (k.clone(), v.clone())).collect(),
74            _ => HashMap::new(),
75        }
76    }
77
78    /// Extract positional args for CLI tools.
79    /// Priority: explicit `raw_args` field > `args` array > `args` string > `args._positional` > empty.
80    fn args_as_positional(&self) -> Vec<String> {
81        // Backward compat: explicit raw_args wins
82        if let Some(ref raw) = self.raw_args {
83            return raw.clone();
84        }
85        match &self.args {
86            // ["pr", "list", "--repo", "X"]
87            Value::Array(arr) => arr
88                .iter()
89                .map(|v| match v {
90                    Value::String(s) => s.clone(),
91                    other => other.to_string(),
92                })
93                .collect(),
94            // "pr list --repo X"
95            Value::String(s) => s.split_whitespace().map(String::from).collect(),
96            // {"_positional": ["pr", "list"]} or {"--key": "value"} converted to CLI flags
97            Value::Object(map) => {
98                if let Some(Value::Array(pos)) = map.get("_positional") {
99                    return pos
100                        .iter()
101                        .map(|v| match v {
102                            Value::String(s) => s.clone(),
103                            other => other.to_string(),
104                        })
105                        .collect();
106                }
107                // Convert map entries to --key value pairs
108                let mut result = Vec::new();
109                for (k, v) in map {
110                    result.push(format!("--{k}"));
111                    match v {
112                        Value::String(s) => result.push(s.clone()),
113                        Value::Bool(true) => {} // flag, no value needed
114                        other => result.push(other.to_string()),
115                    }
116                }
117                result
118            }
119            _ => Vec::new(),
120        }
121    }
122}
123
124#[derive(Debug, Serialize)]
125pub struct CallResponse {
126    pub result: Value,
127    #[serde(skip_serializing_if = "Option::is_none")]
128    pub error: Option<String>,
129}
130
131#[derive(Debug, Deserialize)]
132pub struct HelpRequest {
133    pub query: String,
134    #[serde(default)]
135    pub tool: Option<String>,
136}
137
138#[derive(Debug, Serialize)]
139pub struct HelpResponse {
140    pub content: String,
141    #[serde(skip_serializing_if = "Option::is_none")]
142    pub error: Option<String>,
143}
144
145#[derive(Debug, Serialize)]
146pub struct HealthResponse {
147    pub status: String,
148    pub version: String,
149    pub tools: usize,
150    pub providers: usize,
151    pub skills: usize,
152    pub auth: String,
153}
154
155// --- Skill endpoint types ---
156
157#[derive(Debug, Deserialize)]
158pub struct SkillsQuery {
159    #[serde(default)]
160    pub category: Option<String>,
161    #[serde(default)]
162    pub provider: Option<String>,
163    #[serde(default)]
164    pub tool: Option<String>,
165    #[serde(default)]
166    pub search: Option<String>,
167}
168
169#[derive(Debug, Deserialize)]
170pub struct SkillDetailQuery {
171    #[serde(default)]
172    pub meta: Option<bool>,
173    #[serde(default)]
174    pub refs: Option<bool>,
175}
176
177#[derive(Debug, Deserialize)]
178pub struct SkillResolveRequest {
179    pub scopes: Vec<String>,
180    /// When true, include SKILL.md content in each resolved skill.
181    #[serde(default)]
182    pub include_content: bool,
183}
184
185#[derive(Debug, Deserialize)]
186pub struct SkillBundleBatchRequest {
187    pub names: Vec<String>,
188}
189
190// --- Tool endpoint types ---
191
192#[derive(Debug, Deserialize)]
193pub struct ToolsQuery {
194    #[serde(default)]
195    pub provider: Option<String>,
196    #[serde(default)]
197    pub search: Option<String>,
198}
199
200// --- Handlers ---
201
/// `POST /call` — executes one tool call on behalf of an agent.
///
/// Flow:
/// 1. read the JWT claims (if any) that the auth middleware stored in request extensions;
/// 2. read and parse the body as a [`CallRequest`] (400 on read failure, 422 on bad JSON);
/// 3. look up the tool in the registry (404 if unknown);
/// 4. if the tool declares a scope, check it against the caller's scopes — 403 when auth
///    is enabled but no JWT was provided, or when access is denied; dev mode (no JWT
///    config) is unrestricted;
/// 5. apply the caller's rate limit, if one is configured (429 when exceeded);
/// 6. dispatch on `provider.handler`: `"mcp"`, `"cli"`, or — for anything else — xAI /
///    generic HTTP, whose raw response is then post-processed per the tool manifest;
/// 7. write an audit record with the call duration and the error message (if any).
///
/// Upstream failures map to 502; response-processing failures map to 500 with the raw
/// upstream payload in `result`. Requests rejected in steps 2–5 return before timing
/// starts and are not audited.
async fn handle_call(
    State(state): State<Arc<ProxyState>>,
    req: HttpRequest<Body>,
) -> impl IntoResponse {
    // Extract JWT claims from request extensions (set by auth middleware)
    let claims = req.extensions().get::<TokenClaims>().cloned();

    // Parse request body. The body is capped at 10 MiB; reading fails (→ 400) when the
    // limit is exceeded or the stream errors.
    let body_bytes = match axum::body::to_bytes(req.into_body(), 10 * 1024 * 1024).await {
        Ok(b) => b,
        Err(e) => {
            return (
                StatusCode::BAD_REQUEST,
                Json(CallResponse {
                    result: Value::Null,
                    error: Some(format!("Failed to read request body: {e}")),
                }),
            );
        }
    };

    let call_req: CallRequest = match serde_json::from_slice(&body_bytes) {
        Ok(r) => r,
        Err(e) => {
            return (
                StatusCode::UNPROCESSABLE_ENTITY,
                Json(CallResponse {
                    result: Value::Null,
                    error: Some(format!("Invalid request: {e}")),
                }),
            );
        }
    };

    // NOTE(review): args are logged at debug level and may contain sensitive values.
    tracing::debug!(
        tool = %call_req.tool_name,
        args = ?call_req.args,
        "POST /call"
    );

    // Look up tool in registry
    let (provider, tool) = match state.registry.get_tool(&call_req.tool_name) {
        Some(pt) => pt,
        None => {
            return (
                StatusCode::NOT_FOUND,
                Json(CallResponse {
                    result: Value::Null,
                    error: Some(format!("Unknown tool: '{}'", call_req.tool_name)),
                }),
            );
        }
    };

    // Scope enforcement from JWT claims. Only tools that declare a scope are checked here;
    // unscoped tools proceed without a scope check.
    if let Some(tool_scope) = &tool.scope {
        let scopes = match &claims {
            Some(c) => ScopeConfig::from_jwt(c),
            None if state.jwt_config.is_none() => ScopeConfig::unrestricted(), // Dev mode
            None => {
                return (
                    StatusCode::FORBIDDEN,
                    Json(CallResponse {
                        result: Value::Null,
                        error: Some("Authentication required — no JWT provided".into()),
                    }),
                );
            }
        };

        if let Err(e) = scopes.check_access(&call_req.tool_name, tool_scope) {
            return (
                StatusCode::FORBIDDEN,
                Json(CallResponse {
                    result: Value::Null,
                    error: Some(format!("Access denied: {e}")),
                }),
            );
        }
    }

    // Rate limit check. Scopes are rebuilt here because the block above only computes
    // them for scoped tools. Without claims, `unrestricted()` is used — presumably it
    // carries no rate config (not visible here; confirm in core::scope).
    {
        let scopes = match &claims {
            Some(c) => ScopeConfig::from_jwt(c),
            None => ScopeConfig::unrestricted(),
        };
        if let Some(ref rate_config) = scopes.rate_config {
            if let Err(e) = crate::core::rate::check_and_record(&call_req.tool_name, rate_config) {
                return (
                    StatusCode::TOO_MANY_REQUESTS,
                    Json(CallResponse {
                        result: Value::Null,
                        error: Some(format!("{e}")),
                    }),
                );
            }
        }
    }

    // Build auth generator context from JWT claims; without claims (dev mode) the
    // identity falls back to sub "dev" and scope "*".
    let gen_ctx = GenContext {
        jwt_sub: claims
            .as_ref()
            .map(|c| c.sub.clone())
            .unwrap_or_else(|| "dev".into()),
        jwt_scope: claims
            .as_ref()
            .map(|c| c.scope.clone())
            .unwrap_or_else(|| "*".into()),
        tool_name: call_req.tool_name.clone(),
        timestamp: crate::core::jwt::now_secs(),
    };

    // Execute tool call — dispatch based on handler type, with timing for audit.
    // `agent_sub` is the empty string when there are no claims.
    let agent_sub = claims.as_ref().map(|c| c.sub.clone()).unwrap_or_default();
    let start = std::time::Instant::now();

    let response = match provider.handler.as_str() {
        // MCP providers take the object-form args.
        "mcp" => {
            let args_map = call_req.args_as_map();
            match mcp_client::execute_with_gen(
                provider,
                &call_req.tool_name,
                &args_map,
                &state.keyring,
                Some(&gen_ctx),
                Some(&state.auth_cache),
            )
            .await
            {
                Ok(result) => (
                    StatusCode::OK,
                    Json(CallResponse {
                        result,
                        error: None,
                    }),
                ),
                Err(e) => (
                    StatusCode::BAD_GATEWAY,
                    Json(CallResponse {
                        result: Value::Null,
                        error: Some(format!("MCP error: {e}")),
                    }),
                ),
            }
        }
        // CLI providers take positional args (see `CallRequest::args_as_positional`).
        "cli" => {
            let positional = call_req.args_as_positional();
            match crate::core::cli_executor::execute_with_gen(
                provider,
                &positional,
                &state.keyring,
                Some(&gen_ctx),
                Some(&state.auth_cache),
            )
            .await
            {
                Ok(result) => (
                    StatusCode::OK,
                    Json(CallResponse {
                        result,
                        error: None,
                    }),
                ),
                Err(e) => (
                    StatusCode::BAD_GATEWAY,
                    Json(CallResponse {
                        result: Value::Null,
                        error: Some(format!("CLI error: {e}")),
                    }),
                ),
            }
        }
        // Everything else: xAI uses its own executor (no auth-generator context); all other
        // handlers go through the generic HTTP executor. Both raw responses are then
        // post-processed using the tool's `response` config.
        _ => {
            let args_map = call_req.args_as_map();
            let raw_response = match match provider.handler.as_str() {
                "xai" => xai::execute_xai_tool(provider, tool, &args_map, &state.keyring).await,
                _ => {
                    http::execute_tool_with_gen(
                        provider,
                        tool,
                        &args_map,
                        &state.keyring,
                        Some(&gen_ctx),
                        Some(&state.auth_cache),
                    )
                    .await
                }
            } {
                Ok(resp) => resp,
                Err(e) => {
                    // Early return skips the shared audit call below, so audit here.
                    let duration = start.elapsed();
                    write_proxy_audit(&call_req, &agent_sub, duration, Some(&e.to_string()));
                    return (
                        StatusCode::BAD_GATEWAY,
                        Json(CallResponse {
                            result: Value::Null,
                            error: Some(format!("Upstream API error: {e}")),
                        }),
                    );
                }
            };

            let processed = match response::process_response(&raw_response, tool.response.as_ref())
            {
                Ok(p) => p,
                Err(e) => {
                    // Same as above: audit before the early return. The unprocessed
                    // upstream payload is returned so the caller can still inspect it.
                    let duration = start.elapsed();
                    write_proxy_audit(&call_req, &agent_sub, duration, Some(&e.to_string()));
                    return (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        Json(CallResponse {
                            result: raw_response,
                            error: Some(format!("Response processing error: {e}")),
                        }),
                    );
                }
            };

            (
                StatusCode::OK,
                Json(CallResponse {
                    result: processed,
                    error: None,
                }),
            )
        }
    };

    // Audit every call that reached this point; the error is `None` on success.
    let duration = start.elapsed();
    let error_msg = response.1.error.as_deref();
    write_proxy_audit(&call_req, &agent_sub, duration, error_msg);

    response
}
438
439async fn handle_help(
440    State(state): State<Arc<ProxyState>>,
441    Json(req): Json<HelpRequest>,
442) -> impl IntoResponse {
443    tracing::debug!(query = %req.query, tool = ?req.tool, "POST /help");
444
445    let (llm_provider, llm_tool) = match state.registry.get_tool("_chat_completion") {
446        Some(pt) => pt,
447        None => {
448            return (
449                StatusCode::SERVICE_UNAVAILABLE,
450                Json(HelpResponse {
451                    content: String::new(),
452                    error: Some("No _llm.toml manifest found. Proxy help requires a configured LLM provider.".into()),
453                }),
454            );
455        }
456    };
457
458    let api_key = match llm_provider
459        .auth_key_name
460        .as_deref()
461        .and_then(|k| state.keyring.get(k))
462    {
463        Some(key) => key.to_string(),
464        None => {
465            return (
466                StatusCode::SERVICE_UNAVAILABLE,
467                Json(HelpResponse {
468                    content: String::new(),
469                    error: Some("LLM API key not found in keyring".into()),
470                }),
471            );
472        }
473    };
474
475    let scopes = ScopeConfig::unrestricted();
476    let resolved_skills = skill::resolve_skills(&state.skill_registry, &state.registry, &scopes);
477    let skills_section = if resolved_skills.is_empty() {
478        String::new()
479    } else {
480        format!(
481            "## Available Skills (methodology guides)\n{}",
482            skill::build_skill_context(&resolved_skills)
483        )
484    };
485
486    // Build system prompt — scoped or unscoped
487    let system_prompt = if let Some(ref tool_name) = req.tool {
488        // Scoped mode: narrow tools to the specified tool or provider
489        match build_scoped_prompt(tool_name, &state.registry, &skills_section) {
490            Some(prompt) => prompt,
491            None => {
492                // Fall back to unscoped if tool/provider not found
493                tracing::debug!(scope = %tool_name, "scope not found, falling back to unscoped");
494                let all_tools = state.registry.list_public_tools();
495                let tools_context = build_tool_context(&all_tools);
496                HELP_SYSTEM_PROMPT
497                    .replace("{tools}", &tools_context)
498                    .replace("{skills_section}", &skills_section)
499            }
500        }
501    } else {
502        let all_tools = state.registry.list_public_tools();
503        let tools_context = build_tool_context(&all_tools);
504        HELP_SYSTEM_PROMPT
505            .replace("{tools}", &tools_context)
506            .replace("{skills_section}", &skills_section)
507    };
508
509    let request_body = serde_json::json!({
510        "model": "zai-glm-4.7",
511        "messages": [
512            {"role": "system", "content": system_prompt},
513            {"role": "user", "content": req.query}
514        ],
515        "max_completion_tokens": 1536,
516        "temperature": 0.3
517    });
518
519    let client = reqwest::Client::new();
520    let url = format!(
521        "{}{}",
522        llm_provider.base_url.trim_end_matches('/'),
523        llm_tool.endpoint
524    );
525
526    let response = match client
527        .post(&url)
528        .bearer_auth(&api_key)
529        .json(&request_body)
530        .send()
531        .await
532    {
533        Ok(r) => r,
534        Err(e) => {
535            return (
536                StatusCode::BAD_GATEWAY,
537                Json(HelpResponse {
538                    content: String::new(),
539                    error: Some(format!("LLM request failed: {e}")),
540                }),
541            );
542        }
543    };
544
545    if !response.status().is_success() {
546        let status = response.status();
547        let body = response.text().await.unwrap_or_default();
548        return (
549            StatusCode::BAD_GATEWAY,
550            Json(HelpResponse {
551                content: String::new(),
552                error: Some(format!("LLM API error ({status}): {body}")),
553            }),
554        );
555    }
556
557    let body: Value = match response.json().await {
558        Ok(b) => b,
559        Err(e) => {
560            return (
561                StatusCode::INTERNAL_SERVER_ERROR,
562                Json(HelpResponse {
563                    content: String::new(),
564                    error: Some(format!("Failed to parse LLM response: {e}")),
565                }),
566            );
567        }
568    };
569
570    let content = body
571        .pointer("/choices/0/message/content")
572        .and_then(|c| c.as_str())
573        .unwrap_or("No response from LLM")
574        .to_string();
575
576    (
577        StatusCode::OK,
578        Json(HelpResponse {
579            content,
580            error: None,
581        }),
582    )
583}
584
585async fn handle_health(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
586    let auth = if state.jwt_config.is_some() {
587        "jwt"
588    } else {
589        "disabled"
590    };
591
592    Json(HealthResponse {
593        status: "ok".into(),
594        version: env!("CARGO_PKG_VERSION").into(),
595        tools: state.registry.list_public_tools().len(),
596        providers: state.registry.list_providers().len(),
597        skills: state.skill_registry.skill_count(),
598        auth: auth.into(),
599    })
600}
601
602/// GET /.well-known/jwks.json — serves the public key for JWT validation.
603async fn handle_jwks(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
604    match &state.jwks_json {
605        Some(jwks) => (StatusCode::OK, Json(jwks.clone())),
606        None => (
607            StatusCode::NOT_FOUND,
608            Json(serde_json::json!({"error": "JWKS not configured"})),
609        ),
610    }
611}
612
613// ---------------------------------------------------------------------------
614// POST /mcp — MCP JSON-RPC proxy endpoint
615// ---------------------------------------------------------------------------
616
617async fn handle_mcp(
618    State(state): State<Arc<ProxyState>>,
619    Json(msg): Json<Value>,
620) -> impl IntoResponse {
621    let method = msg.get("method").and_then(|m| m.as_str()).unwrap_or("");
622    let id = msg.get("id").cloned();
623
624    tracing::debug!(%method, "POST /mcp");
625
626    match method {
627        "initialize" => {
628            let result = serde_json::json!({
629                "protocolVersion": "2025-03-26",
630                "capabilities": {
631                    "tools": { "listChanged": false }
632                },
633                "serverInfo": {
634                    "name": "ati-proxy",
635                    "version": env!("CARGO_PKG_VERSION")
636                }
637            });
638            jsonrpc_success(id, result)
639        }
640
641        "notifications/initialized" => (StatusCode::ACCEPTED, Json(Value::Null)),
642
643        "tools/list" => {
644            let all_tools = state.registry.list_public_tools();
645            let mcp_tools: Vec<Value> = all_tools
646                .iter()
647                .map(|(_provider, tool)| {
648                    serde_json::json!({
649                        "name": tool.name,
650                        "description": tool.description,
651                        "inputSchema": tool.input_schema.clone().unwrap_or(serde_json::json!({
652                            "type": "object",
653                            "properties": {}
654                        }))
655                    })
656                })
657                .collect();
658
659            let result = serde_json::json!({
660                "tools": mcp_tools,
661            });
662            jsonrpc_success(id, result)
663        }
664
665        "tools/call" => {
666            let params = msg.get("params").cloned().unwrap_or(Value::Null);
667            let tool_name = params.get("name").and_then(|n| n.as_str()).unwrap_or("");
668            let arguments: HashMap<String, Value> = params
669                .get("arguments")
670                .and_then(|a| serde_json::from_value(a.clone()).ok())
671                .unwrap_or_default();
672
673            if tool_name.is_empty() {
674                return jsonrpc_error(id, -32602, "Missing tool name in params.name");
675            }
676
677            let (provider, _tool) = match state.registry.get_tool(tool_name) {
678                Some(pt) => pt,
679                None => {
680                    return jsonrpc_error(id, -32602, &format!("Unknown tool: '{tool_name}'"));
681                }
682            };
683
684            tracing::debug!(%tool_name, provider = %provider.name, "MCP tools/call");
685
686            let mcp_gen_ctx = GenContext {
687                jwt_sub: "dev".into(),
688                jwt_scope: "*".into(),
689                tool_name: tool_name.to_string(),
690                timestamp: crate::core::jwt::now_secs(),
691            };
692
693            let result = if provider.is_mcp() {
694                mcp_client::execute_with_gen(
695                    provider,
696                    tool_name,
697                    &arguments,
698                    &state.keyring,
699                    Some(&mcp_gen_ctx),
700                    Some(&state.auth_cache),
701                )
702                .await
703            } else if provider.is_cli() {
704                // Convert arguments map to CLI-style args for MCP passthrough
705                let raw: Vec<String> = arguments
706                    .iter()
707                    .flat_map(|(k, v)| {
708                        let val = match v {
709                            Value::String(s) => s.clone(),
710                            other => other.to_string(),
711                        };
712                        vec![format!("--{k}"), val]
713                    })
714                    .collect();
715                crate::core::cli_executor::execute_with_gen(
716                    provider,
717                    &raw,
718                    &state.keyring,
719                    Some(&mcp_gen_ctx),
720                    Some(&state.auth_cache),
721                )
722                .await
723                .map_err(|e| mcp_client::McpError::Transport(e.to_string()))
724            } else {
725                match match provider.handler.as_str() {
726                    "xai" => {
727                        xai::execute_xai_tool(provider, _tool, &arguments, &state.keyring).await
728                    }
729                    _ => {
730                        http::execute_tool_with_gen(
731                            provider,
732                            _tool,
733                            &arguments,
734                            &state.keyring,
735                            Some(&mcp_gen_ctx),
736                            Some(&state.auth_cache),
737                        )
738                        .await
739                    }
740                } {
741                    Ok(val) => Ok(val),
742                    Err(e) => Err(mcp_client::McpError::Transport(e.to_string())),
743                }
744            };
745
746            match result {
747                Ok(value) => {
748                    let text = match &value {
749                        Value::String(s) => s.clone(),
750                        other => serde_json::to_string_pretty(other).unwrap_or_default(),
751                    };
752                    let mcp_result = serde_json::json!({
753                        "content": [{"type": "text", "text": text}],
754                        "isError": false,
755                    });
756                    jsonrpc_success(id, mcp_result)
757                }
758                Err(e) => {
759                    let mcp_result = serde_json::json!({
760                        "content": [{"type": "text", "text": format!("Error: {e}")}],
761                        "isError": true,
762                    });
763                    jsonrpc_success(id, mcp_result)
764                }
765            }
766        }
767
768        _ => jsonrpc_error(id, -32601, &format!("Method not found: '{method}'")),
769    }
770}
771
772fn jsonrpc_success(id: Option<Value>, result: Value) -> (StatusCode, Json<Value>) {
773    (
774        StatusCode::OK,
775        Json(serde_json::json!({
776            "jsonrpc": "2.0",
777            "id": id,
778            "result": result,
779        })),
780    )
781}
782
783fn jsonrpc_error(id: Option<Value>, code: i64, message: &str) -> (StatusCode, Json<Value>) {
784    (
785        StatusCode::OK,
786        Json(serde_json::json!({
787            "jsonrpc": "2.0",
788            "id": id,
789            "error": {
790                "code": code,
791                "message": message,
792            }
793        })),
794    )
795}
796
797// ---------------------------------------------------------------------------
798// Tool endpoints
799// ---------------------------------------------------------------------------
800
801/// GET /tools — list available tools with optional filters.
802async fn handle_tools_list(
803    State(state): State<Arc<ProxyState>>,
804    axum::extract::Query(query): axum::extract::Query<ToolsQuery>,
805) -> impl IntoResponse {
806    tracing::debug!(
807        provider = ?query.provider,
808        search = ?query.search,
809        "GET /tools"
810    );
811
812    let all_tools = state.registry.list_public_tools();
813
814    let tools: Vec<Value> = all_tools
815        .iter()
816        .filter(|(provider, tool)| {
817            if let Some(ref p) = query.provider {
818                if provider.name != *p {
819                    return false;
820                }
821            }
822            if let Some(ref q) = query.search {
823                let q = q.to_lowercase();
824                let name_match = tool.name.to_lowercase().contains(&q);
825                let desc_match = tool.description.to_lowercase().contains(&q);
826                let tag_match = tool.tags.iter().any(|t| t.to_lowercase().contains(&q));
827                if !name_match && !desc_match && !tag_match {
828                    return false;
829                }
830            }
831            true
832        })
833        .map(|(provider, tool)| {
834            serde_json::json!({
835                "name": tool.name,
836                "description": tool.description,
837                "provider": provider.name,
838                "method": format!("{:?}", tool.method),
839                "tags": tool.tags,
840                "input_schema": tool.input_schema,
841            })
842        })
843        .collect();
844
845    (StatusCode::OK, Json(Value::Array(tools)))
846}
847
848/// GET /tools/:name — get detailed info about a specific tool.
849async fn handle_tool_info(
850    State(state): State<Arc<ProxyState>>,
851    axum::extract::Path(name): axum::extract::Path<String>,
852) -> impl IntoResponse {
853    tracing::debug!(tool = %name, "GET /tools/:name");
854
855    match state.registry.get_tool(&name) {
856        Some((provider, tool)) => (
857            StatusCode::OK,
858            Json(serde_json::json!({
859                "name": tool.name,
860                "description": tool.description,
861                "provider": provider.name,
862                "method": format!("{:?}", tool.method),
863                "endpoint": tool.endpoint,
864                "tags": tool.tags,
865                "hint": tool.hint,
866                "input_schema": tool.input_schema,
867                "scope": tool.scope,
868            })),
869        ),
870        None => (
871            StatusCode::NOT_FOUND,
872            Json(serde_json::json!({"error": format!("Tool '{name}' not found")})),
873        ),
874    }
875}
876
877// ---------------------------------------------------------------------------
878// Skill endpoints
879// ---------------------------------------------------------------------------
880
881async fn handle_skills_list(
882    State(state): State<Arc<ProxyState>>,
883    axum::extract::Query(query): axum::extract::Query<SkillsQuery>,
884) -> impl IntoResponse {
885    tracing::debug!(
886        category = ?query.category,
887        provider = ?query.provider,
888        tool = ?query.tool,
889        search = ?query.search,
890        "GET /skills"
891    );
892
893    let skills: Vec<&skill::SkillMeta> = if let Some(search_query) = &query.search {
894        state.skill_registry.search(search_query)
895    } else if let Some(cat) = &query.category {
896        state.skill_registry.skills_for_category(cat)
897    } else if let Some(prov) = &query.provider {
898        state.skill_registry.skills_for_provider(prov)
899    } else if let Some(t) = &query.tool {
900        state.skill_registry.skills_for_tool(t)
901    } else {
902        state.skill_registry.list_skills().iter().collect()
903    };
904
905    let json: Vec<Value> = skills
906        .iter()
907        .map(|s| {
908            serde_json::json!({
909                "name": s.name,
910                "version": s.version,
911                "description": s.description,
912                "tools": s.tools,
913                "providers": s.providers,
914                "categories": s.categories,
915                "hint": s.hint,
916            })
917        })
918        .collect();
919
920    (StatusCode::OK, Json(Value::Array(json)))
921}
922
923async fn handle_skill_detail(
924    State(state): State<Arc<ProxyState>>,
925    axum::extract::Path(name): axum::extract::Path<String>,
926    axum::extract::Query(query): axum::extract::Query<SkillDetailQuery>,
927) -> impl IntoResponse {
928    tracing::debug!(%name, meta = ?query.meta, refs = ?query.refs, "GET /skills/:name");
929
930    let skill_meta = match state.skill_registry.get_skill(&name) {
931        Some(s) => s,
932        None => {
933            return (
934                StatusCode::NOT_FOUND,
935                Json(serde_json::json!({"error": format!("Skill '{name}' not found")})),
936            );
937        }
938    };
939
940    if query.meta.unwrap_or(false) {
941        return (
942            StatusCode::OK,
943            Json(serde_json::json!({
944                "name": skill_meta.name,
945                "version": skill_meta.version,
946                "description": skill_meta.description,
947                "author": skill_meta.author,
948                "tools": skill_meta.tools,
949                "providers": skill_meta.providers,
950                "categories": skill_meta.categories,
951                "keywords": skill_meta.keywords,
952                "hint": skill_meta.hint,
953                "depends_on": skill_meta.depends_on,
954                "suggests": skill_meta.suggests,
955                "license": skill_meta.license,
956                "compatibility": skill_meta.compatibility,
957                "allowed_tools": skill_meta.allowed_tools,
958                "format": skill_meta.format,
959            })),
960        );
961    }
962
963    let content = match state.skill_registry.read_content(&name) {
964        Ok(c) => c,
965        Err(e) => {
966            return (
967                StatusCode::INTERNAL_SERVER_ERROR,
968                Json(serde_json::json!({"error": format!("Failed to read skill: {e}")})),
969            );
970        }
971    };
972
973    let mut response = serde_json::json!({
974        "name": skill_meta.name,
975        "version": skill_meta.version,
976        "description": skill_meta.description,
977        "content": content,
978    });
979
980    if query.refs.unwrap_or(false) {
981        if let Ok(refs) = state.skill_registry.list_references(&name) {
982            response["references"] = serde_json::json!(refs);
983        }
984    }
985
986    (StatusCode::OK, Json(response))
987}
988
989/// GET /skills/:name/bundle — return all files in a skill directory.
990/// Response: `{"name": "...", "files": {"SKILL.md": "...", "scripts/generate.sh": "...", ...}}`
991/// Binary files are base64-encoded; text files are returned as-is.
992async fn handle_skill_bundle(
993    State(state): State<Arc<ProxyState>>,
994    axum::extract::Path(name): axum::extract::Path<String>,
995) -> impl IntoResponse {
996    tracing::debug!(skill = %name, "GET /skills/:name/bundle");
997
998    let files = match state.skill_registry.bundle_files(&name) {
999        Ok(f) => f,
1000        Err(_) => {
1001            return (
1002                StatusCode::NOT_FOUND,
1003                Json(serde_json::json!({"error": format!("Skill '{name}' not found")})),
1004            );
1005        }
1006    };
1007
1008    // Convert bytes to strings (UTF-8 text) or base64 for binary files
1009    let mut file_map = serde_json::Map::new();
1010    for (path, data) in &files {
1011        match std::str::from_utf8(data) {
1012            Ok(text) => {
1013                file_map.insert(path.clone(), Value::String(text.to_string()));
1014            }
1015            Err(_) => {
1016                // Binary file — base64 encode
1017                use base64::Engine;
1018                let encoded = base64::engine::general_purpose::STANDARD.encode(data);
1019                file_map.insert(path.clone(), serde_json::json!({"base64": encoded}));
1020            }
1021        }
1022    }
1023
1024    (
1025        StatusCode::OK,
1026        Json(serde_json::json!({
1027            "name": name,
1028            "files": file_map,
1029        })),
1030    )
1031}
1032
1033/// POST /skills/bundle — return all files for multiple skills in one response.
1034/// Request: `{"names": ["fal-generate", "compliance-screening"]}`
1035/// Response: `{"skills": {...}, "missing": [...]}`
1036async fn handle_skills_bundle_batch(
1037    State(state): State<Arc<ProxyState>>,
1038    Json(req): Json<SkillBundleBatchRequest>,
1039) -> impl IntoResponse {
1040    const MAX_BATCH: usize = 50;
1041    if req.names.len() > MAX_BATCH {
1042        return (
1043            StatusCode::BAD_REQUEST,
1044            Json(
1045                serde_json::json!({"error": format!("batch size {} exceeds limit of {MAX_BATCH}", req.names.len())}),
1046            ),
1047        );
1048    }
1049
1050    tracing::debug!(names = ?req.names, "POST /skills/bundle");
1051
1052    let mut result = serde_json::Map::new();
1053    let mut missing: Vec<String> = Vec::new();
1054
1055    for name in &req.names {
1056        let files = match state.skill_registry.bundle_files(name) {
1057            Ok(f) => f,
1058            Err(_) => {
1059                missing.push(name.clone());
1060                continue;
1061            }
1062        };
1063
1064        let mut file_map = serde_json::Map::new();
1065        for (path, data) in &files {
1066            match std::str::from_utf8(data) {
1067                Ok(text) => {
1068                    file_map.insert(path.clone(), Value::String(text.to_string()));
1069                }
1070                Err(_) => {
1071                    use base64::Engine;
1072                    let encoded = base64::engine::general_purpose::STANDARD.encode(data);
1073                    file_map.insert(path.clone(), serde_json::json!({"base64": encoded}));
1074                }
1075            }
1076        }
1077
1078        result.insert(name.clone(), serde_json::json!({ "files": file_map }));
1079    }
1080
1081    (
1082        StatusCode::OK,
1083        Json(serde_json::json!({ "skills": result, "missing": missing })),
1084    )
1085}
1086
1087async fn handle_skills_resolve(
1088    State(state): State<Arc<ProxyState>>,
1089    Json(req): Json<SkillResolveRequest>,
1090) -> impl IntoResponse {
1091    tracing::debug!(scopes = ?req.scopes, include_content = req.include_content, "POST /skills/resolve");
1092
1093    let include_content = req.include_content;
1094    let scopes = ScopeConfig {
1095        scopes: req.scopes,
1096        sub: String::new(),
1097        expires_at: 0,
1098        rate_config: None,
1099    };
1100
1101    let resolved = skill::resolve_skills(&state.skill_registry, &state.registry, &scopes);
1102
1103    let json: Vec<Value> = resolved
1104        .iter()
1105        .map(|s| {
1106            let mut entry = serde_json::json!({
1107                "name": s.name,
1108                "version": s.version,
1109                "description": s.description,
1110                "tools": s.tools,
1111                "providers": s.providers,
1112                "categories": s.categories,
1113            });
1114            if include_content {
1115                if let Ok(content) = state.skill_registry.read_content(&s.name) {
1116                    entry["content"] = Value::String(content);
1117                }
1118            }
1119            entry
1120        })
1121        .collect();
1122
1123    (StatusCode::OK, Json(Value::Array(json)))
1124}
1125
1126// --- Auth middleware ---
1127
1128/// JWT authentication middleware.
1129///
1130/// - /health and /.well-known/jwks.json → skip auth
1131/// - JWT configured → validate Bearer token, attach claims to request extensions
1132/// - No JWT configured → allow all (dev mode)
1133async fn auth_middleware(
1134    State(state): State<Arc<ProxyState>>,
1135    mut req: HttpRequest<Body>,
1136    next: Next,
1137) -> Result<Response, StatusCode> {
1138    let path = req.uri().path();
1139
1140    // Skip auth for public endpoints
1141    if path == "/health" || path == "/.well-known/jwks.json" {
1142        return Ok(next.run(req).await);
1143    }
1144
1145    // If no JWT configured, allow all (dev mode)
1146    let jwt_config = match &state.jwt_config {
1147        Some(c) => c,
1148        None => return Ok(next.run(req).await),
1149    };
1150
1151    // Extract Authorization: Bearer <token>
1152    let auth_header = req
1153        .headers()
1154        .get("authorization")
1155        .and_then(|v| v.to_str().ok());
1156
1157    let token = match auth_header {
1158        Some(header) if header.starts_with("Bearer ") => &header[7..],
1159        _ => return Err(StatusCode::UNAUTHORIZED),
1160    };
1161
1162    // Validate JWT
1163    match jwt::validate(token, jwt_config) {
1164        Ok(claims) => {
1165            tracing::debug!(sub = %claims.sub, scopes = %claims.scope, "JWT validated");
1166            req.extensions_mut().insert(claims);
1167            Ok(next.run(req).await)
1168        }
1169        Err(e) => {
1170            tracing::debug!(error = %e, "JWT validation failed");
1171            Err(StatusCode::UNAUTHORIZED)
1172        }
1173    }
1174}
1175
1176// --- Router builder ---
1177
1178/// Build the axum Router from a pre-constructed ProxyState.
1179pub fn build_router(state: Arc<ProxyState>) -> Router {
1180    Router::new()
1181        .route("/call", post(handle_call))
1182        .route("/help", post(handle_help))
1183        .route("/mcp", post(handle_mcp))
1184        .route("/tools", get(handle_tools_list))
1185        .route("/tools/{name}", get(handle_tool_info))
1186        .route("/skills", get(handle_skills_list))
1187        .route("/skills/resolve", post(handle_skills_resolve))
1188        .route("/skills/bundle", post(handle_skills_bundle_batch))
1189        .route("/skills/{name}", get(handle_skill_detail))
1190        .route("/skills/{name}/bundle", get(handle_skill_bundle))
1191        .route("/health", get(handle_health))
1192        .route("/.well-known/jwks.json", get(handle_jwks))
1193        .layer(middleware::from_fn_with_state(
1194            state.clone(),
1195            auth_middleware,
1196        ))
1197        .with_state(state)
1198}
1199
1200// --- Server startup ---
1201
1202/// Start the proxy server.
1203pub async fn run(
1204    port: u16,
1205    bind_addr: Option<String>,
1206    ati_dir: PathBuf,
1207    _verbose: bool,
1208    env_keys: bool,
1209) -> Result<(), Box<dyn std::error::Error>> {
1210    // Load manifests
1211    let manifests_dir = ati_dir.join("manifests");
1212    let registry = ManifestRegistry::load(&manifests_dir)?;
1213
1214    let tool_count = registry.list_public_tools().len();
1215    let provider_count = registry.list_providers().len();
1216
1217    // Load keyring
1218    let keyring_source;
1219    let keyring = if env_keys {
1220        // --env-keys: scan ATI_KEY_* environment variables
1221        let kr = Keyring::from_env();
1222        let key_names = kr.key_names();
1223        tracing::info!(
1224            count = key_names.len(),
1225            "loaded API keys from ATI_KEY_* env vars"
1226        );
1227        for name in &key_names {
1228            tracing::debug!(key = %name, "env key loaded");
1229        }
1230        keyring_source = "env-vars (ATI_KEY_*)";
1231        kr
1232    } else {
1233        // Cascade: keyring.enc (sealed) → keyring.enc (persistent) → credentials → empty
1234        let keyring_path = ati_dir.join("keyring.enc");
1235        if keyring_path.exists() {
1236            if let Ok(kr) = Keyring::load(&keyring_path) {
1237                keyring_source = "keyring.enc (sealed key)";
1238                kr
1239            } else if let Ok(kr) = Keyring::load_local(&keyring_path, &ati_dir) {
1240                keyring_source = "keyring.enc (persistent key)";
1241                kr
1242            } else {
1243                tracing::warn!("keyring.enc exists but could not be decrypted");
1244                keyring_source = "empty (decryption failed)";
1245                Keyring::empty()
1246            }
1247        } else {
1248            let creds_path = ati_dir.join("credentials");
1249            if creds_path.exists() {
1250                match Keyring::load_credentials(&creds_path) {
1251                    Ok(kr) => {
1252                        keyring_source = "credentials (plaintext)";
1253                        kr
1254                    }
1255                    Err(e) => {
1256                        tracing::warn!(error = %e, "failed to load credentials");
1257                        keyring_source = "empty (credentials error)";
1258                        Keyring::empty()
1259                    }
1260                }
1261            } else {
1262                tracing::warn!("no keyring.enc or credentials found — running without API keys");
1263                tracing::warn!("tools requiring authentication will fail");
1264                keyring_source = "empty (no auth)";
1265                Keyring::empty()
1266            }
1267        }
1268    };
1269
1270    // Log MCP and OpenAPI providers
1271    let mcp_providers: Vec<(String, String)> = registry
1272        .list_mcp_providers()
1273        .iter()
1274        .map(|p| (p.name.clone(), p.mcp_transport_type().to_string()))
1275        .collect();
1276    let mcp_count = mcp_providers.len();
1277    let openapi_providers: Vec<String> = registry
1278        .list_openapi_providers()
1279        .iter()
1280        .map(|p| p.name.clone())
1281        .collect();
1282    let openapi_count = openapi_providers.len();
1283
1284    // Load skill registry (local + optional GCS)
1285    let skills_dir = ati_dir.join("skills");
1286    let mut skill_registry = SkillRegistry::load(&skills_dir).unwrap_or_else(|e| {
1287        tracing::warn!(error = %e, "failed to load skills");
1288        SkillRegistry::load(std::path::Path::new("/nonexistent-fallback")).unwrap()
1289    });
1290
1291    // Load GCS skills if ATI_SKILL_REGISTRY is set
1292    if let Ok(registry_url) = std::env::var("ATI_SKILL_REGISTRY") {
1293        if let Some(bucket) = registry_url.strip_prefix("gcs://") {
1294            let cred_key = "gcp_credentials";
1295            if let Some(cred_json) = keyring.get(cred_key) {
1296                match crate::core::gcs::GcsClient::new(bucket.to_string(), cred_json) {
1297                    Ok(client) => match crate::core::gcs::GcsSkillSource::load(&client).await {
1298                        Ok(gcs_source) => {
1299                            let gcs_count = gcs_source.skill_count();
1300                            skill_registry.merge(gcs_source);
1301                            tracing::info!(
1302                                bucket = %bucket,
1303                                skills = gcs_count,
1304                                "loaded skills from GCS registry"
1305                            );
1306                        }
1307                        Err(e) => {
1308                            tracing::warn!(error = %e, bucket = %bucket, "failed to load GCS skills");
1309                        }
1310                    },
1311                    Err(e) => {
1312                        tracing::warn!(error = %e, "failed to init GCS client");
1313                    }
1314                }
1315            } else {
1316                tracing::warn!(
1317                    key = %cred_key,
1318                    "ATI_SKILL_REGISTRY set but GCS credentials not found in keyring"
1319                );
1320            }
1321        } else {
1322            tracing::warn!(
1323                url = %registry_url,
1324                "unsupported skill registry scheme (only gcs:// is supported)"
1325            );
1326        }
1327    }
1328
1329    let skill_count = skill_registry.skill_count();
1330
1331    // Load JWT config from environment
1332    let jwt_config = match jwt::config_from_env() {
1333        Ok(config) => config,
1334        Err(e) => {
1335            tracing::warn!(error = %e, "JWT config error");
1336            None
1337        }
1338    };
1339
1340    let auth_status = if jwt_config.is_some() {
1341        "JWT enabled"
1342    } else {
1343        "DISABLED (no JWT keys configured)"
1344    };
1345
1346    // Build JWKS for the endpoint
1347    let jwks_json = jwt_config.as_ref().and_then(|config| {
1348        config
1349            .public_key_pem
1350            .as_ref()
1351            .and_then(|pem| jwt::public_key_to_jwks(pem, config.algorithm, "ati-proxy-1").ok())
1352    });
1353
1354    let state = Arc::new(ProxyState {
1355        registry,
1356        skill_registry,
1357        keyring,
1358        jwt_config,
1359        jwks_json,
1360        auth_cache: AuthCache::new(),
1361    });
1362
1363    let app = build_router(state);
1364
1365    let addr: SocketAddr = if let Some(ref bind) = bind_addr {
1366        format!("{bind}:{port}").parse()?
1367    } else {
1368        SocketAddr::from(([127, 0, 0, 1], port))
1369    };
1370
1371    tracing::info!(
1372        version = env!("CARGO_PKG_VERSION"),
1373        %addr,
1374        auth = auth_status,
1375        ati_dir = %ati_dir.display(),
1376        tools = tool_count,
1377        providers = provider_count,
1378        mcp = mcp_count,
1379        openapi = openapi_count,
1380        skills = skill_count,
1381        keyring = keyring_source,
1382        "ATI proxy server starting"
1383    );
1384    for (name, transport) in &mcp_providers {
1385        tracing::info!(provider = %name, transport = %transport, "MCP provider");
1386    }
1387    for name in &openapi_providers {
1388        tracing::info!(provider = %name, "OpenAPI provider");
1389    }
1390
1391    let listener = tokio::net::TcpListener::bind(addr).await?;
1392    axum::serve(listener, app).await?;
1393
1394    Ok(())
1395}
1396
1397/// Write an audit entry from the proxy server. Failures are silently ignored.
1398fn write_proxy_audit(
1399    call_req: &CallRequest,
1400    agent_sub: &str,
1401    duration: std::time::Duration,
1402    error: Option<&str>,
1403) {
1404    let entry = crate::core::audit::AuditEntry {
1405        ts: chrono::Utc::now().to_rfc3339(),
1406        tool: call_req.tool_name.clone(),
1407        args: crate::core::audit::sanitize_args(&call_req.args),
1408        status: if error.is_some() {
1409            crate::core::audit::AuditStatus::Error
1410        } else {
1411            crate::core::audit::AuditStatus::Ok
1412        },
1413        duration_ms: duration.as_millis() as u64,
1414        agent_sub: agent_sub.to_string(),
1415        error: error.map(|s| s.to_string()),
1416        exit_code: None,
1417    };
1418    let _ = crate::core::audit::append(&entry);
1419}
1420
1421// --- Helpers ---
1422
/// System prompt template for the general (unscoped) help assistant.
///
/// The template contains two literal placeholders, `{tools}` and
/// `{skills_section}`. A `const &str` cannot be used as a `format!` string,
/// because `format!` requires a literal. NOTE(review): the placeholders are
/// presumably filled by the caller with `str::replace`; confirm at the call
/// site. The text is sent to the model verbatim, so edits change assistant
/// behavior.
const HELP_SYSTEM_PROMPT: &str = r#"You are a helpful assistant for an AI agent that uses external tools via the `ati` CLI.

## Available Tools
{tools}

{skills_section}

Answer the agent's question naturally, like a knowledgeable colleague would. Keep it short but useful:

- Explain which tools to use and why, with `ati run` commands showing realistic parameter values
- If multiple steps are needed, walk through them briefly in order
- Mention important gotchas or parameter choices that matter
- If skills are relevant, suggest `ati skill show <name>` for the full methodology

Keep your answer concise — a few short paragraphs with embedded code blocks. Only recommend tools from the list above."#;
1438
1439fn build_tool_context(
1440    tools: &[(
1441        &crate::core::manifest::Provider,
1442        &crate::core::manifest::Tool,
1443    )],
1444) -> String {
1445    let mut summaries = Vec::new();
1446    for (provider, tool) in tools {
1447        let mut summary = if let Some(cat) = &provider.category {
1448            format!(
1449                "- **{}** (provider: {}, category: {}): {}",
1450                tool.name, provider.name, cat, tool.description
1451            )
1452        } else {
1453            format!(
1454                "- **{}** (provider: {}): {}",
1455                tool.name, provider.name, tool.description
1456            )
1457        };
1458        if !tool.tags.is_empty() {
1459            summary.push_str(&format!("\n  Tags: {}", tool.tags.join(", ")));
1460        }
1461        // CLI tools: show passthrough usage
1462        if provider.is_cli() && tool.input_schema.is_none() {
1463            let cmd = provider.cli_command.as_deref().unwrap_or("?");
1464            summary.push_str(&format!(
1465                "\n  Usage: `ati run {} -- <args>`  (passthrough to `{}`)",
1466                tool.name, cmd
1467            ));
1468        } else if let Some(schema) = &tool.input_schema {
1469            if let Some(props) = schema.get("properties") {
1470                if let Some(obj) = props.as_object() {
1471                    let params: Vec<String> = obj
1472                        .iter()
1473                        .filter(|(_, v)| {
1474                            v.get("x-ati-param-location").is_none()
1475                                || v.get("description").is_some()
1476                        })
1477                        .map(|(k, v)| {
1478                            let type_str =
1479                                v.get("type").and_then(|t| t.as_str()).unwrap_or("string");
1480                            let desc = v.get("description").and_then(|d| d.as_str()).unwrap_or("");
1481                            format!("    --{k} ({type_str}): {desc}")
1482                        })
1483                        .collect();
1484                    if !params.is_empty() {
1485                        summary.push_str("\n  Parameters:\n");
1486                        summary.push_str(&params.join("\n"));
1487                    }
1488                }
1489            }
1490        }
1491        summaries.push(summary);
1492    }
1493    summaries.join("\n\n")
1494}
1495
1496/// Build a scoped system prompt for a specific tool or provider.
1497///
1498/// Returns None if the scope_name doesn't match any tool or provider.
1499fn build_scoped_prompt(
1500    scope_name: &str,
1501    registry: &ManifestRegistry,
1502    skills_section: &str,
1503) -> Option<String> {
1504    // Check if scope_name is a tool
1505    if let Some((provider, tool)) = registry.get_tool(scope_name) {
1506        let mut details = format!(
1507            "**Name**: `{}`\n**Provider**: {} (handler: {})\n**Description**: {}\n",
1508            tool.name, provider.name, provider.handler, tool.description
1509        );
1510        if let Some(cat) = &provider.category {
1511            details.push_str(&format!("**Category**: {}\n", cat));
1512        }
1513        if provider.is_cli() {
1514            let cmd = provider.cli_command.as_deref().unwrap_or("?");
1515            details.push_str(&format!(
1516                "\n**Usage**: `ati run {} -- <args>`  (passthrough to `{}`)\n",
1517                tool.name, cmd
1518            ));
1519        } else if let Some(schema) = &tool.input_schema {
1520            if let Some(props) = schema.get("properties").and_then(|p| p.as_object()) {
1521                let required: Vec<String> = schema
1522                    .get("required")
1523                    .and_then(|r| r.as_array())
1524                    .map(|arr| {
1525                        arr.iter()
1526                            .filter_map(|v| v.as_str().map(|s| s.to_string()))
1527                            .collect()
1528                    })
1529                    .unwrap_or_default();
1530                details.push_str("\n**Parameters**:\n");
1531                for (key, val) in props {
1532                    let type_str = val.get("type").and_then(|t| t.as_str()).unwrap_or("string");
1533                    let desc = val
1534                        .get("description")
1535                        .and_then(|d| d.as_str())
1536                        .unwrap_or("");
1537                    let req = if required.contains(key) {
1538                        " **(required)**"
1539                    } else {
1540                        ""
1541                    };
1542                    details.push_str(&format!("- `--{key}` ({type_str}{req}): {desc}\n"));
1543                }
1544            }
1545        }
1546
1547        let prompt = format!(
1548            "You are an expert assistant for the `{}` tool, accessed via the `ati` CLI.\n\n\
1549            ## Tool Details\n{}\n\n{}\n\n\
1550            Answer the agent's question about this specific tool. Provide exact commands, explain flags and options, and give practical examples. Be concise and actionable.",
1551            tool.name, details, skills_section
1552        );
1553        return Some(prompt);
1554    }
1555
1556    // Check if scope_name is a provider
1557    if registry.has_provider(scope_name) {
1558        let tools = registry.tools_by_provider(scope_name);
1559        if tools.is_empty() {
1560            return None;
1561        }
1562        let tools_context = build_tool_context(&tools);
1563        let prompt = format!(
1564            "You are an expert assistant for the `{}` provider's tools, accessed via the `ati` CLI.\n\n\
1565            ## Tools in provider `{}`\n{}\n\n{}\n\n\
1566            Answer the agent's question about these tools. Provide exact `ati run` commands, explain parameters, and give practical examples. Be concise and actionable.",
1567            scope_name, scope_name, tools_context, skills_section
1568        );
1569        return Some(prompt);
1570    }
1571
1572    None
1573}