1use axum::{
8 body::Body,
9 extract::{Extension, Query, State},
10 http::{Request as HttpRequest, StatusCode},
11 middleware::{self, Next},
12 response::{IntoResponse, Response},
13 routing::{get, post},
14 Json, Router,
15};
16use serde::{Deserialize, Serialize};
17use serde_json::Value;
18use std::collections::HashMap;
19use std::net::SocketAddr;
20use std::path::PathBuf;
21use std::sync::Arc;
22
23use crate::core::auth_generator::{AuthCache, GenContext};
24use crate::core::http;
25use crate::core::jwt::{self, JwtConfig, TokenClaims};
26use crate::core::keyring::Keyring;
27use crate::core::manifest::{ManifestRegistry, Provider, Tool};
28use crate::core::mcp_client;
29use crate::core::response;
30use crate::core::scope::ScopeConfig;
31use crate::core::skill::{self, SkillRegistry};
32use crate::core::skillati::{RemoteSkillMeta, SkillAtiClient, SkillAtiError};
33use crate::core::xai;
34
/// Shared state for the proxy's HTTP handlers.
///
/// The handlers in this file receive it as `State<Arc<ProxyState>>`.
pub struct ProxyState {
    /// Provider/tool manifests; handlers look tools up here by name.
    pub registry: ManifestRegistry,
    /// Skill metadata and content served by the skill handlers.
    pub skill_registry: SkillRegistry,
    /// Credential store handed to the tool executors and used to look up the
    /// help LLM's API key.
    pub keyring: Keyring,
    /// JWT settings. When `None`, the health handler reports auth as
    /// `"disabled"` and requests without claims are treated as unrestricted.
    pub jwt_config: Option<JwtConfig>,
    /// JWKS document returned as-is by the JWKS handler; `None` makes that
    /// handler answer 404.
    pub jwks_json: Option<Value>,
    /// Cache passed to the executors together with the per-call `GenContext`.
    pub auth_cache: AuthCache,
}
47
/// JSON body accepted by the tool-call handler (logged as `POST /call`).
#[derive(Debug, Deserialize)]
pub struct CallRequest {
    /// Name of the tool to invoke, as registered in the manifest registry.
    pub tool_name: String,
    /// Tool arguments. Defaults to an empty JSON object when omitted; see
    /// `args_as_map` and `args_as_positional` for how each JSON shape is read.
    #[serde(default = "default_args")]
    pub args: Value,
    /// Optional pre-split argument vector. When present, `args_as_positional`
    /// returns it unchanged and ignores `args`.
    #[serde(default)]
    pub raw_args: Option<Vec<String>>,
}
63
64fn default_args() -> Value {
65 Value::Object(serde_json::Map::new())
66}
67
68impl CallRequest {
69 fn args_as_map(&self) -> HashMap<String, Value> {
73 match &self.args {
74 Value::Object(map) => map.iter().map(|(k, v)| (k.clone(), v.clone())).collect(),
75 _ => HashMap::new(),
76 }
77 }
78
79 fn args_as_positional(&self) -> Vec<String> {
82 if let Some(ref raw) = self.raw_args {
84 return raw.clone();
85 }
86 match &self.args {
87 Value::Array(arr) => arr
89 .iter()
90 .map(|v| match v {
91 Value::String(s) => s.clone(),
92 other => other.to_string(),
93 })
94 .collect(),
95 Value::String(s) => s.split_whitespace().map(String::from).collect(),
97 Value::Object(map) => {
99 if let Some(Value::Array(pos)) = map.get("_positional") {
100 return pos
101 .iter()
102 .map(|v| match v {
103 Value::String(s) => s.clone(),
104 other => other.to_string(),
105 })
106 .collect();
107 }
108 let mut result = Vec::new();
110 for (k, v) in map {
111 result.push(format!("--{k}"));
112 match v {
113 Value::String(s) => result.push(s.clone()),
114 Value::Bool(true) => {} other => result.push(other.to_string()),
116 }
117 }
118 result
119 }
120 _ => Vec::new(),
121 }
122 }
123}
124
/// JSON body returned by the tool-call handler.
#[derive(Debug, Serialize)]
pub struct CallResponse {
    /// Tool output on success. The call handler sets it to `null` on most
    /// failures; on a response-processing failure it carries the raw upstream
    /// response instead.
    pub result: Value,
    /// Error description; omitted from the serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
131
/// JSON body accepted by the help handler (logged as `POST /help`).
#[derive(Debug, Deserialize)]
pub struct HelpRequest {
    /// The user's question; sent to the LLM as the user message.
    pub query: String,
    /// Optional name used to build a scoped system prompt; it is also
    /// prepended to the query used for the remote skill lookup.
    #[serde(default)]
    pub tool: Option<String>,
}
138
/// JSON body returned by the help handler.
#[derive(Debug, Serialize)]
pub struct HelpResponse {
    /// The LLM's answer; the help handler leaves it empty on every error path.
    pub content: String,
    /// Error description; omitted from the serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
145
/// JSON body returned by the health handler.
#[derive(Debug, Serialize)]
pub struct HealthResponse {
    /// Always `"ok"` as produced by the health handler in this file.
    pub status: String,
    /// Crate version taken from `CARGO_PKG_VERSION` at build time.
    pub version: String,
    /// Number of public tools in the manifest registry.
    pub tools: usize,
    /// Number of providers in the manifest registry.
    pub providers: usize,
    /// Number of skills reported by the skill registry.
    pub skills: usize,
    /// `"jwt"` when a JWT configuration is present, otherwise `"disabled"`.
    pub auth: String,
}
155
/// Query-string filters for the skill listing handler (logged as `GET /skills`).
///
/// Only one filter is applied per request. The handler checks them in the
/// order `search`, `category`, `provider`, `tool` and uses the first one set.
#[derive(Debug, Deserialize)]
pub struct SkillsQuery {
    /// Restrict to skills in this category.
    #[serde(default)]
    pub category: Option<String>,
    /// Restrict to skills associated with this provider.
    #[serde(default)]
    pub provider: Option<String>,
    /// Restrict to skills associated with this tool.
    #[serde(default)]
    pub tool: Option<String>,
    /// Search text passed to the skill registry's search.
    #[serde(default)]
    pub search: Option<String>,
}
169
/// Query-string options for the skill detail handler (logged as `GET /skills/:name`).
#[derive(Debug, Deserialize)]
pub struct SkillDetailQuery {
    /// When `true`, return only the skill's metadata (no content).
    #[serde(default)]
    pub meta: Option<bool>,
    /// When `true`, add a `references` list to the content response.
    #[serde(default)]
    pub refs: Option<bool>,
}
177
/// JSON body accepted by the skill resolve handler (logged as `POST /skills/resolve`).
#[derive(Debug, Deserialize)]
pub struct SkillResolveRequest {
    /// Scopes to resolve skills for. The handler additionally restricts the
    /// result to skills visible to the caller.
    pub scopes: Vec<String>,
    /// When `true`, each resolved skill also carries its content. Defaults to
    /// `false` when omitted.
    #[serde(default)]
    pub include_content: bool,
}
185
/// JSON body accepted by the batch bundle handler (logged as `POST /skills/bundle`).
#[derive(Debug, Deserialize)]
pub struct SkillBundleBatchRequest {
    /// Names of the skills to bundle. The handler rejects more than 50 names.
    pub names: Vec<String>,
}
190
/// Query-string options for the remote catalog handler (logged as `GET /skillati/catalog`).
#[derive(Debug, Deserialize, Default)]
pub struct SkillAtiCatalogQuery {
    /// Optional search text; when set, the handler passes the visible catalog
    /// through `SkillAtiClient::filter_catalog` with it.
    #[serde(default)]
    pub search: Option<String>,
}
196
/// Query-string options for a remote skill resource listing.
///
/// NOTE(review): the handler that consumes this type is not in the visible
/// part of the file — confirm the exact semantics there.
#[derive(Debug, Deserialize, Default)]
pub struct SkillAtiResourcesQuery {
    /// Optional prefix; presumably limits the listing to resource paths that
    /// start with it — TODO confirm against the handler.
    #[serde(default)]
    pub prefix: Option<String>,
}
202
/// Query-string parameters for fetching a single remote skill file.
///
/// NOTE(review): the handler that consumes this type is not in the visible
/// part of the file — confirm the exact semantics there.
#[derive(Debug, Deserialize)]
pub struct SkillAtiFileQuery {
    /// Required path of the requested file; presumably relative to the skill
    /// root — TODO confirm against the handler.
    pub path: String,
}
207
/// Query-string filters for the tool listing handler (logged as `GET /tools`).
///
/// Both filters may be combined; a tool must satisfy every filter that is set.
#[derive(Debug, Deserialize)]
pub struct ToolsQuery {
    /// Keep only tools whose provider name equals this value exactly.
    #[serde(default)]
    pub provider: Option<String>,
    /// Case-insensitive substring matched against the tool's name,
    /// description, and tags.
    #[serde(default)]
    pub search: Option<String>,
}
217
218fn scopes_for_request(claims: Option<&TokenClaims>, state: &ProxyState) -> ScopeConfig {
221 match claims {
222 Some(claims) => ScopeConfig::from_jwt(claims),
223 None if state.jwt_config.is_none() => ScopeConfig::unrestricted(),
224 None => ScopeConfig {
225 scopes: Vec::new(),
226 sub: String::new(),
227 expires_at: 0,
228 rate_config: None,
229 },
230 }
231}
232
233fn visible_tools_for_scopes<'a>(
234 state: &'a ProxyState,
235 scopes: &ScopeConfig,
236) -> Vec<(&'a Provider, &'a Tool)> {
237 crate::core::scope::filter_tools_by_scope(state.registry.list_public_tools(), scopes)
238}
239
240fn visible_skill_names(
241 state: &ProxyState,
242 scopes: &ScopeConfig,
243) -> std::collections::HashSet<String> {
244 skill::visible_skills(&state.skill_registry, &state.registry, scopes)
245 .into_iter()
246 .map(|skill| skill.name.clone())
247 .collect()
248}
249
250async fn handle_call(
251 State(state): State<Arc<ProxyState>>,
252 req: HttpRequest<Body>,
253) -> impl IntoResponse {
254 let claims = req.extensions().get::<TokenClaims>().cloned();
256
257 let body_bytes = match axum::body::to_bytes(req.into_body(), 10 * 1024 * 1024).await {
259 Ok(b) => b,
260 Err(e) => {
261 return (
262 StatusCode::BAD_REQUEST,
263 Json(CallResponse {
264 result: Value::Null,
265 error: Some(format!("Failed to read request body: {e}")),
266 }),
267 );
268 }
269 };
270
271 let call_req: CallRequest = match serde_json::from_slice(&body_bytes) {
272 Ok(r) => r,
273 Err(e) => {
274 return (
275 StatusCode::UNPROCESSABLE_ENTITY,
276 Json(CallResponse {
277 result: Value::Null,
278 error: Some(format!("Invalid request: {e}")),
279 }),
280 );
281 }
282 };
283
284 tracing::debug!(
285 tool = %call_req.tool_name,
286 args = ?call_req.args,
287 "POST /call"
288 );
289
290 let (provider, tool) = match state.registry.get_tool(&call_req.tool_name) {
293 Some(pt) => pt,
294 None => {
295 let mut resolved = None;
299 for (idx, _) in call_req.tool_name.match_indices('_') {
300 let candidate = format!(
301 "{}:{}",
302 &call_req.tool_name[..idx],
303 &call_req.tool_name[idx + 1..]
304 );
305 if let Some(pt) = state.registry.get_tool(&candidate) {
306 tracing::debug!(
307 original = %call_req.tool_name,
308 resolved = %candidate,
309 "resolved underscore tool name to colon format"
310 );
311 resolved = Some(pt);
312 break;
313 }
314 }
315
316 match resolved {
317 Some(pt) => pt,
318 None => {
319 return (
320 StatusCode::NOT_FOUND,
321 Json(CallResponse {
322 result: Value::Null,
323 error: Some(format!("Unknown tool: '{}'", call_req.tool_name)),
324 }),
325 );
326 }
327 }
328 }
329 };
330
331 if let Some(tool_scope) = &tool.scope {
333 let scopes = match &claims {
334 Some(c) => ScopeConfig::from_jwt(c),
335 None if state.jwt_config.is_none() => ScopeConfig::unrestricted(), None => {
337 return (
338 StatusCode::FORBIDDEN,
339 Json(CallResponse {
340 result: Value::Null,
341 error: Some("Authentication required — no JWT provided".into()),
342 }),
343 );
344 }
345 };
346
347 if !scopes.is_allowed(tool_scope) {
348 return (
349 StatusCode::FORBIDDEN,
350 Json(CallResponse {
351 result: Value::Null,
352 error: Some(format!(
353 "Access denied: '{}' is not in your scopes",
354 tool.name
355 )),
356 }),
357 );
358 }
359 }
360
361 {
363 let scopes = match &claims {
364 Some(c) => ScopeConfig::from_jwt(c),
365 None => ScopeConfig::unrestricted(),
366 };
367 if let Some(ref rate_config) = scopes.rate_config {
368 if let Err(e) = crate::core::rate::check_and_record(&call_req.tool_name, rate_config) {
369 return (
370 StatusCode::TOO_MANY_REQUESTS,
371 Json(CallResponse {
372 result: Value::Null,
373 error: Some(format!("{e}")),
374 }),
375 );
376 }
377 }
378 }
379
380 let gen_ctx = GenContext {
382 jwt_sub: claims
383 .as_ref()
384 .map(|c| c.sub.clone())
385 .unwrap_or_else(|| "dev".into()),
386 jwt_scope: claims
387 .as_ref()
388 .map(|c| c.scope.clone())
389 .unwrap_or_else(|| "*".into()),
390 tool_name: call_req.tool_name.clone(),
391 timestamp: crate::core::jwt::now_secs(),
392 };
393
394 let agent_sub = claims.as_ref().map(|c| c.sub.clone()).unwrap_or_default();
396 let start = std::time::Instant::now();
397
398 let response = match provider.handler.as_str() {
399 "mcp" => {
400 let args_map = call_req.args_as_map();
401 match mcp_client::execute_with_gen(
402 provider,
403 &call_req.tool_name,
404 &args_map,
405 &state.keyring,
406 Some(&gen_ctx),
407 Some(&state.auth_cache),
408 )
409 .await
410 {
411 Ok(result) => (
412 StatusCode::OK,
413 Json(CallResponse {
414 result,
415 error: None,
416 }),
417 ),
418 Err(e) => (
419 StatusCode::BAD_GATEWAY,
420 Json(CallResponse {
421 result: Value::Null,
422 error: Some(format!("MCP error: {e}")),
423 }),
424 ),
425 }
426 }
427 "cli" => {
428 let positional = call_req.args_as_positional();
429 match crate::core::cli_executor::execute_with_gen(
430 provider,
431 &positional,
432 &state.keyring,
433 Some(&gen_ctx),
434 Some(&state.auth_cache),
435 )
436 .await
437 {
438 Ok(result) => (
439 StatusCode::OK,
440 Json(CallResponse {
441 result,
442 error: None,
443 }),
444 ),
445 Err(e) => (
446 StatusCode::BAD_GATEWAY,
447 Json(CallResponse {
448 result: Value::Null,
449 error: Some(format!("CLI error: {e}")),
450 }),
451 ),
452 }
453 }
454 _ => {
455 let args_map = call_req.args_as_map();
456 let raw_response = match match provider.handler.as_str() {
457 "xai" => xai::execute_xai_tool(provider, tool, &args_map, &state.keyring).await,
458 _ => {
459 http::execute_tool_with_gen(
460 provider,
461 tool,
462 &args_map,
463 &state.keyring,
464 Some(&gen_ctx),
465 Some(&state.auth_cache),
466 )
467 .await
468 }
469 } {
470 Ok(resp) => resp,
471 Err(e) => {
472 let duration = start.elapsed();
473 write_proxy_audit(&call_req, &agent_sub, duration, Some(&e.to_string()));
474 return (
475 StatusCode::BAD_GATEWAY,
476 Json(CallResponse {
477 result: Value::Null,
478 error: Some(format!("Upstream API error: {e}")),
479 }),
480 );
481 }
482 };
483
484 let processed = match response::process_response(&raw_response, tool.response.as_ref())
485 {
486 Ok(p) => p,
487 Err(e) => {
488 let duration = start.elapsed();
489 write_proxy_audit(&call_req, &agent_sub, duration, Some(&e.to_string()));
490 return (
491 StatusCode::INTERNAL_SERVER_ERROR,
492 Json(CallResponse {
493 result: raw_response,
494 error: Some(format!("Response processing error: {e}")),
495 }),
496 );
497 }
498 };
499
500 (
501 StatusCode::OK,
502 Json(CallResponse {
503 result: processed,
504 error: None,
505 }),
506 )
507 }
508 };
509
510 let duration = start.elapsed();
511 let error_msg = response.1.error.as_deref();
512 write_proxy_audit(&call_req, &agent_sub, duration, error_msg);
513
514 response
515}
516
517async fn handle_help(
518 State(state): State<Arc<ProxyState>>,
519 claims: Option<Extension<TokenClaims>>,
520 Json(req): Json<HelpRequest>,
521) -> impl IntoResponse {
522 tracing::debug!(query = %req.query, tool = ?req.tool, "POST /help");
523
524 let claims = claims.map(|Extension(claims)| claims);
525 let scopes = scopes_for_request(claims.as_ref(), &state);
526 if !scopes.help_enabled() {
527 return (
528 StatusCode::FORBIDDEN,
529 Json(HelpResponse {
530 content: String::new(),
531 error: Some("Help is not enabled in your scopes.".into()),
532 }),
533 );
534 }
535
536 let (llm_provider, llm_tool) = match state.registry.get_tool("_chat_completion") {
537 Some(pt) => pt,
538 None => {
539 return (
540 StatusCode::SERVICE_UNAVAILABLE,
541 Json(HelpResponse {
542 content: String::new(),
543 error: Some("No _llm.toml manifest found. Proxy help requires a configured LLM provider.".into()),
544 }),
545 );
546 }
547 };
548
549 let api_key = match llm_provider
550 .auth_key_name
551 .as_deref()
552 .and_then(|k| state.keyring.get(k))
553 {
554 Some(key) => key.to_string(),
555 None => {
556 return (
557 StatusCode::SERVICE_UNAVAILABLE,
558 Json(HelpResponse {
559 content: String::new(),
560 error: Some("LLM API key not found in keyring".into()),
561 }),
562 );
563 }
564 };
565
566 let resolved_skills = skill::resolve_skills(&state.skill_registry, &state.registry, &scopes);
567 let local_skills_section = if resolved_skills.is_empty() {
568 String::new()
569 } else {
570 format!(
571 "## Available Skills (methodology guides)\n{}",
572 skill::build_skill_context(&resolved_skills)
573 )
574 };
575 let remote_query = req
576 .tool
577 .as_ref()
578 .map(|tool| format!("{tool} {}", req.query))
579 .unwrap_or_else(|| req.query.clone());
580 let remote_skills_section =
581 build_remote_skillati_section(&state.keyring, &remote_query, 12).await;
582 let skills_section = merge_help_skill_sections(&[local_skills_section, remote_skills_section]);
583
584 let visible_tools = visible_tools_for_scopes(&state, &scopes);
586 let system_prompt = if let Some(ref tool_name) = req.tool {
587 match build_scoped_prompt(tool_name, &visible_tools, &skills_section) {
589 Some(prompt) => prompt,
590 None => {
591 return (
592 StatusCode::FORBIDDEN,
593 Json(HelpResponse {
594 content: String::new(),
595 error: Some(format!(
596 "Scope '{tool_name}' is not visible in your current scopes."
597 )),
598 }),
599 );
600 }
601 }
602 } else {
603 let tools_context = build_tool_context(&visible_tools);
604 HELP_SYSTEM_PROMPT
605 .replace("{tools}", &tools_context)
606 .replace("{skills_section}", &skills_section)
607 };
608
609 let request_body = serde_json::json!({
610 "model": "zai-glm-4.7",
611 "messages": [
612 {"role": "system", "content": system_prompt},
613 {"role": "user", "content": req.query}
614 ],
615 "max_completion_tokens": 1536,
616 "temperature": 0.3
617 });
618
619 let client = reqwest::Client::new();
620 let url = format!(
621 "{}{}",
622 llm_provider.base_url.trim_end_matches('/'),
623 llm_tool.endpoint
624 );
625
626 let response = match client
627 .post(&url)
628 .bearer_auth(&api_key)
629 .json(&request_body)
630 .send()
631 .await
632 {
633 Ok(r) => r,
634 Err(e) => {
635 return (
636 StatusCode::BAD_GATEWAY,
637 Json(HelpResponse {
638 content: String::new(),
639 error: Some(format!("LLM request failed: {e}")),
640 }),
641 );
642 }
643 };
644
645 if !response.status().is_success() {
646 let status = response.status();
647 let body = response.text().await.unwrap_or_default();
648 return (
649 StatusCode::BAD_GATEWAY,
650 Json(HelpResponse {
651 content: String::new(),
652 error: Some(format!("LLM API error ({status}): {body}")),
653 }),
654 );
655 }
656
657 let body: Value = match response.json().await {
658 Ok(b) => b,
659 Err(e) => {
660 return (
661 StatusCode::INTERNAL_SERVER_ERROR,
662 Json(HelpResponse {
663 content: String::new(),
664 error: Some(format!("Failed to parse LLM response: {e}")),
665 }),
666 );
667 }
668 };
669
670 let content = body
671 .pointer("/choices/0/message/content")
672 .and_then(|c| c.as_str())
673 .unwrap_or("No response from LLM")
674 .to_string();
675
676 (
677 StatusCode::OK,
678 Json(HelpResponse {
679 content,
680 error: None,
681 }),
682 )
683}
684
685async fn handle_health(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
686 let auth = if state.jwt_config.is_some() {
687 "jwt"
688 } else {
689 "disabled"
690 };
691
692 Json(HealthResponse {
693 status: "ok".into(),
694 version: env!("CARGO_PKG_VERSION").into(),
695 tools: state.registry.list_public_tools().len(),
696 providers: state.registry.list_providers().len(),
697 skills: state.skill_registry.skill_count(),
698 auth: auth.into(),
699 })
700}
701
702async fn handle_jwks(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
704 match &state.jwks_json {
705 Some(jwks) => (StatusCode::OK, Json(jwks.clone())),
706 None => (
707 StatusCode::NOT_FOUND,
708 Json(serde_json::json!({"error": "JWKS not configured"})),
709 ),
710 }
711}
712
713async fn handle_mcp(
718 State(state): State<Arc<ProxyState>>,
719 claims: Option<Extension<TokenClaims>>,
720 Json(msg): Json<Value>,
721) -> impl IntoResponse {
722 let claims = claims.map(|Extension(claims)| claims);
723 let scopes = scopes_for_request(claims.as_ref(), &state);
724 let method = msg.get("method").and_then(|m| m.as_str()).unwrap_or("");
725 let id = msg.get("id").cloned();
726
727 tracing::debug!(%method, "POST /mcp");
728
729 match method {
730 "initialize" => {
731 let result = serde_json::json!({
732 "protocolVersion": "2025-03-26",
733 "capabilities": {
734 "tools": { "listChanged": false }
735 },
736 "serverInfo": {
737 "name": "ati-proxy",
738 "version": env!("CARGO_PKG_VERSION")
739 }
740 });
741 jsonrpc_success(id, result)
742 }
743
744 "notifications/initialized" => (StatusCode::ACCEPTED, Json(Value::Null)),
745
746 "tools/list" => {
747 let visible_tools = visible_tools_for_scopes(&state, &scopes);
748 let mcp_tools: Vec<Value> = visible_tools
749 .iter()
750 .map(|(_provider, tool)| {
751 serde_json::json!({
752 "name": tool.name,
753 "description": tool.description,
754 "inputSchema": tool.input_schema.clone().unwrap_or(serde_json::json!({
755 "type": "object",
756 "properties": {}
757 }))
758 })
759 })
760 .collect();
761
762 let result = serde_json::json!({
763 "tools": mcp_tools,
764 });
765 jsonrpc_success(id, result)
766 }
767
768 "tools/call" => {
769 let params = msg.get("params").cloned().unwrap_or(Value::Null);
770 let tool_name = params.get("name").and_then(|n| n.as_str()).unwrap_or("");
771 let arguments: HashMap<String, Value> = params
772 .get("arguments")
773 .and_then(|a| serde_json::from_value(a.clone()).ok())
774 .unwrap_or_default();
775
776 if tool_name.is_empty() {
777 return jsonrpc_error(id, -32602, "Missing tool name in params.name");
778 }
779
780 let (provider, _tool) = match state.registry.get_tool(tool_name) {
781 Some(pt) => pt,
782 None => {
783 return jsonrpc_error(id, -32602, &format!("Unknown tool: '{tool_name}'"));
784 }
785 };
786
787 if let Some(tool_scope) = &_tool.scope {
788 if !scopes.is_allowed(tool_scope) {
789 return jsonrpc_error(
790 id,
791 -32001,
792 &format!("Access denied: '{}' is not in your scopes", _tool.name),
793 );
794 }
795 }
796
797 tracing::debug!(%tool_name, provider = %provider.name, "MCP tools/call");
798
799 let mcp_gen_ctx = GenContext {
800 jwt_sub: claims
801 .as_ref()
802 .map(|claims| claims.sub.clone())
803 .unwrap_or_else(|| "dev".into()),
804 jwt_scope: claims
805 .as_ref()
806 .map(|claims| claims.scope.clone())
807 .unwrap_or_else(|| "*".into()),
808 tool_name: tool_name.to_string(),
809 timestamp: crate::core::jwt::now_secs(),
810 };
811
812 let result = if provider.is_mcp() {
813 mcp_client::execute_with_gen(
814 provider,
815 tool_name,
816 &arguments,
817 &state.keyring,
818 Some(&mcp_gen_ctx),
819 Some(&state.auth_cache),
820 )
821 .await
822 } else if provider.is_cli() {
823 let raw: Vec<String> = arguments
825 .iter()
826 .flat_map(|(k, v)| {
827 let val = match v {
828 Value::String(s) => s.clone(),
829 other => other.to_string(),
830 };
831 vec![format!("--{k}"), val]
832 })
833 .collect();
834 crate::core::cli_executor::execute_with_gen(
835 provider,
836 &raw,
837 &state.keyring,
838 Some(&mcp_gen_ctx),
839 Some(&state.auth_cache),
840 )
841 .await
842 .map_err(|e| mcp_client::McpError::Transport(e.to_string()))
843 } else {
844 match match provider.handler.as_str() {
845 "xai" => {
846 xai::execute_xai_tool(provider, _tool, &arguments, &state.keyring).await
847 }
848 _ => {
849 http::execute_tool_with_gen(
850 provider,
851 _tool,
852 &arguments,
853 &state.keyring,
854 Some(&mcp_gen_ctx),
855 Some(&state.auth_cache),
856 )
857 .await
858 }
859 } {
860 Ok(val) => Ok(val),
861 Err(e) => Err(mcp_client::McpError::Transport(e.to_string())),
862 }
863 };
864
865 match result {
866 Ok(value) => {
867 let text = match &value {
868 Value::String(s) => s.clone(),
869 other => serde_json::to_string_pretty(other).unwrap_or_default(),
870 };
871 let mcp_result = serde_json::json!({
872 "content": [{"type": "text", "text": text}],
873 "isError": false,
874 });
875 jsonrpc_success(id, mcp_result)
876 }
877 Err(e) => {
878 let mcp_result = serde_json::json!({
879 "content": [{"type": "text", "text": format!("Error: {e}")}],
880 "isError": true,
881 });
882 jsonrpc_success(id, mcp_result)
883 }
884 }
885 }
886
887 _ => jsonrpc_error(id, -32601, &format!("Method not found: '{method}'")),
888 }
889}
890
891fn jsonrpc_success(id: Option<Value>, result: Value) -> (StatusCode, Json<Value>) {
892 (
893 StatusCode::OK,
894 Json(serde_json::json!({
895 "jsonrpc": "2.0",
896 "id": id,
897 "result": result,
898 })),
899 )
900}
901
902fn jsonrpc_error(id: Option<Value>, code: i64, message: &str) -> (StatusCode, Json<Value>) {
903 (
904 StatusCode::OK,
905 Json(serde_json::json!({
906 "jsonrpc": "2.0",
907 "id": id,
908 "error": {
909 "code": code,
910 "message": message,
911 }
912 })),
913 )
914}
915
916async fn handle_tools_list(
922 State(state): State<Arc<ProxyState>>,
923 claims: Option<Extension<TokenClaims>>,
924 axum::extract::Query(query): axum::extract::Query<ToolsQuery>,
925) -> impl IntoResponse {
926 tracing::debug!(
927 provider = ?query.provider,
928 search = ?query.search,
929 "GET /tools"
930 );
931
932 let claims = claims.map(|Extension(claims)| claims);
933 let scopes = scopes_for_request(claims.as_ref(), &state);
934 let all_tools = visible_tools_for_scopes(&state, &scopes);
935
936 let tools: Vec<Value> = all_tools
937 .iter()
938 .filter(|(provider, tool)| {
939 if let Some(ref p) = query.provider {
940 if provider.name != *p {
941 return false;
942 }
943 }
944 if let Some(ref q) = query.search {
945 let q = q.to_lowercase();
946 let name_match = tool.name.to_lowercase().contains(&q);
947 let desc_match = tool.description.to_lowercase().contains(&q);
948 let tag_match = tool.tags.iter().any(|t| t.to_lowercase().contains(&q));
949 if !name_match && !desc_match && !tag_match {
950 return false;
951 }
952 }
953 true
954 })
955 .map(|(provider, tool)| {
956 serde_json::json!({
957 "name": tool.name,
958 "description": tool.description,
959 "provider": provider.name,
960 "method": format!("{:?}", tool.method),
961 "tags": tool.tags,
962 "input_schema": tool.input_schema,
963 })
964 })
965 .collect();
966
967 (StatusCode::OK, Json(Value::Array(tools)))
968}
969
970async fn handle_tool_info(
972 State(state): State<Arc<ProxyState>>,
973 claims: Option<Extension<TokenClaims>>,
974 axum::extract::Path(name): axum::extract::Path<String>,
975) -> impl IntoResponse {
976 tracing::debug!(tool = %name, "GET /tools/:name");
977
978 let claims = claims.map(|Extension(claims)| claims);
979 let scopes = scopes_for_request(claims.as_ref(), &state);
980
981 match state
982 .registry
983 .get_tool(&name)
984 .filter(|(_, tool)| match &tool.scope {
985 Some(scope) => scopes.is_allowed(scope),
986 None => true,
987 }) {
988 Some((provider, tool)) => (
989 StatusCode::OK,
990 Json(serde_json::json!({
991 "name": tool.name,
992 "description": tool.description,
993 "provider": provider.name,
994 "method": format!("{:?}", tool.method),
995 "endpoint": tool.endpoint,
996 "tags": tool.tags,
997 "hint": tool.hint,
998 "input_schema": tool.input_schema,
999 "scope": tool.scope,
1000 })),
1001 ),
1002 None => (
1003 StatusCode::NOT_FOUND,
1004 Json(serde_json::json!({"error": format!("Tool '{name}' not found")})),
1005 ),
1006 }
1007}
1008
1009async fn handle_skills_list(
1014 State(state): State<Arc<ProxyState>>,
1015 claims: Option<Extension<TokenClaims>>,
1016 axum::extract::Query(query): axum::extract::Query<SkillsQuery>,
1017) -> impl IntoResponse {
1018 tracing::debug!(
1019 category = ?query.category,
1020 provider = ?query.provider,
1021 tool = ?query.tool,
1022 search = ?query.search,
1023 "GET /skills"
1024 );
1025
1026 let claims = claims.map(|Extension(claims)| claims);
1027 let scopes = scopes_for_request(claims.as_ref(), &state);
1028 let visible_names = visible_skill_names(&state, &scopes);
1029
1030 let skills: Vec<&skill::SkillMeta> = if let Some(search_query) = &query.search {
1031 state
1032 .skill_registry
1033 .search(search_query)
1034 .into_iter()
1035 .filter(|skill| visible_names.contains(&skill.name))
1036 .collect()
1037 } else if let Some(cat) = &query.category {
1038 state
1039 .skill_registry
1040 .skills_for_category(cat)
1041 .into_iter()
1042 .filter(|skill| visible_names.contains(&skill.name))
1043 .collect()
1044 } else if let Some(prov) = &query.provider {
1045 state
1046 .skill_registry
1047 .skills_for_provider(prov)
1048 .into_iter()
1049 .filter(|skill| visible_names.contains(&skill.name))
1050 .collect()
1051 } else if let Some(t) = &query.tool {
1052 state
1053 .skill_registry
1054 .skills_for_tool(t)
1055 .into_iter()
1056 .filter(|skill| visible_names.contains(&skill.name))
1057 .collect()
1058 } else {
1059 state
1060 .skill_registry
1061 .list_skills()
1062 .iter()
1063 .filter(|skill| visible_names.contains(&skill.name))
1064 .collect()
1065 };
1066
1067 let json: Vec<Value> = skills
1068 .iter()
1069 .map(|s| {
1070 serde_json::json!({
1071 "name": s.name,
1072 "version": s.version,
1073 "description": s.description,
1074 "tools": s.tools,
1075 "providers": s.providers,
1076 "categories": s.categories,
1077 "hint": s.hint,
1078 })
1079 })
1080 .collect();
1081
1082 (StatusCode::OK, Json(Value::Array(json)))
1083}
1084
1085async fn handle_skill_detail(
1086 State(state): State<Arc<ProxyState>>,
1087 claims: Option<Extension<TokenClaims>>,
1088 axum::extract::Path(name): axum::extract::Path<String>,
1089 axum::extract::Query(query): axum::extract::Query<SkillDetailQuery>,
1090) -> impl IntoResponse {
1091 tracing::debug!(%name, meta = ?query.meta, refs = ?query.refs, "GET /skills/:name");
1092
1093 let claims = claims.map(|Extension(claims)| claims);
1094 let scopes = scopes_for_request(claims.as_ref(), &state);
1095 let visible_names = visible_skill_names(&state, &scopes);
1096
1097 let skill_meta = match state
1098 .skill_registry
1099 .get_skill(&name)
1100 .filter(|skill| visible_names.contains(&skill.name))
1101 {
1102 Some(s) => s,
1103 None => {
1104 return (
1105 StatusCode::NOT_FOUND,
1106 Json(serde_json::json!({"error": format!("Skill '{name}' not found")})),
1107 );
1108 }
1109 };
1110
1111 if query.meta.unwrap_or(false) {
1112 return (
1113 StatusCode::OK,
1114 Json(serde_json::json!({
1115 "name": skill_meta.name,
1116 "version": skill_meta.version,
1117 "description": skill_meta.description,
1118 "author": skill_meta.author,
1119 "tools": skill_meta.tools,
1120 "providers": skill_meta.providers,
1121 "categories": skill_meta.categories,
1122 "keywords": skill_meta.keywords,
1123 "hint": skill_meta.hint,
1124 "depends_on": skill_meta.depends_on,
1125 "suggests": skill_meta.suggests,
1126 "license": skill_meta.license,
1127 "compatibility": skill_meta.compatibility,
1128 "allowed_tools": skill_meta.allowed_tools,
1129 "format": skill_meta.format,
1130 })),
1131 );
1132 }
1133
1134 let content = match state.skill_registry.read_content(&name) {
1135 Ok(c) => c,
1136 Err(e) => {
1137 return (
1138 StatusCode::INTERNAL_SERVER_ERROR,
1139 Json(serde_json::json!({"error": format!("Failed to read skill: {e}")})),
1140 );
1141 }
1142 };
1143
1144 let mut response = serde_json::json!({
1145 "name": skill_meta.name,
1146 "version": skill_meta.version,
1147 "description": skill_meta.description,
1148 "content": content,
1149 });
1150
1151 if query.refs.unwrap_or(false) {
1152 if let Ok(refs) = state.skill_registry.list_references(&name) {
1153 response["references"] = serde_json::json!(refs);
1154 }
1155 }
1156
1157 (StatusCode::OK, Json(response))
1158}
1159
1160async fn handle_skill_bundle(
1164 State(state): State<Arc<ProxyState>>,
1165 claims: Option<Extension<TokenClaims>>,
1166 axum::extract::Path(name): axum::extract::Path<String>,
1167) -> impl IntoResponse {
1168 tracing::debug!(skill = %name, "GET /skills/:name/bundle");
1169
1170 let claims = claims.map(|Extension(claims)| claims);
1171 let scopes = scopes_for_request(claims.as_ref(), &state);
1172 let visible_names = visible_skill_names(&state, &scopes);
1173 if !visible_names.contains(&name) {
1174 return (
1175 StatusCode::NOT_FOUND,
1176 Json(serde_json::json!({"error": format!("Skill '{name}' not found")})),
1177 );
1178 }
1179
1180 let files = match state.skill_registry.bundle_files(&name) {
1181 Ok(f) => f,
1182 Err(_) => {
1183 return (
1184 StatusCode::NOT_FOUND,
1185 Json(serde_json::json!({"error": format!("Skill '{name}' not found")})),
1186 );
1187 }
1188 };
1189
1190 let mut file_map = serde_json::Map::new();
1192 for (path, data) in &files {
1193 match std::str::from_utf8(data) {
1194 Ok(text) => {
1195 file_map.insert(path.clone(), Value::String(text.to_string()));
1196 }
1197 Err(_) => {
1198 use base64::Engine;
1200 let encoded = base64::engine::general_purpose::STANDARD.encode(data);
1201 file_map.insert(path.clone(), serde_json::json!({"base64": encoded}));
1202 }
1203 }
1204 }
1205
1206 (
1207 StatusCode::OK,
1208 Json(serde_json::json!({
1209 "name": name,
1210 "files": file_map,
1211 })),
1212 )
1213}
1214
1215async fn handle_skills_bundle_batch(
1219 State(state): State<Arc<ProxyState>>,
1220 claims: Option<Extension<TokenClaims>>,
1221 Json(req): Json<SkillBundleBatchRequest>,
1222) -> impl IntoResponse {
1223 const MAX_BATCH: usize = 50;
1224 if req.names.len() > MAX_BATCH {
1225 return (
1226 StatusCode::BAD_REQUEST,
1227 Json(
1228 serde_json::json!({"error": format!("batch size {} exceeds limit of {MAX_BATCH}", req.names.len())}),
1229 ),
1230 );
1231 }
1232
1233 tracing::debug!(names = ?req.names, "POST /skills/bundle");
1234
1235 let claims = claims.map(|Extension(claims)| claims);
1236 let scopes = scopes_for_request(claims.as_ref(), &state);
1237 let visible_names = visible_skill_names(&state, &scopes);
1238
1239 let mut result = serde_json::Map::new();
1240 let mut missing: Vec<String> = Vec::new();
1241
1242 for name in &req.names {
1243 if !visible_names.contains(name) {
1244 missing.push(name.clone());
1245 continue;
1246 }
1247 let files = match state.skill_registry.bundle_files(name) {
1248 Ok(f) => f,
1249 Err(_) => {
1250 missing.push(name.clone());
1251 continue;
1252 }
1253 };
1254
1255 let mut file_map = serde_json::Map::new();
1256 for (path, data) in &files {
1257 match std::str::from_utf8(data) {
1258 Ok(text) => {
1259 file_map.insert(path.clone(), Value::String(text.to_string()));
1260 }
1261 Err(_) => {
1262 use base64::Engine;
1263 let encoded = base64::engine::general_purpose::STANDARD.encode(data);
1264 file_map.insert(path.clone(), serde_json::json!({"base64": encoded}));
1265 }
1266 }
1267 }
1268
1269 result.insert(name.clone(), serde_json::json!({ "files": file_map }));
1270 }
1271
1272 (
1273 StatusCode::OK,
1274 Json(serde_json::json!({ "skills": result, "missing": missing })),
1275 )
1276}
1277
1278async fn handle_skills_resolve(
1279 State(state): State<Arc<ProxyState>>,
1280 claims: Option<Extension<TokenClaims>>,
1281 Json(req): Json<SkillResolveRequest>,
1282) -> impl IntoResponse {
1283 tracing::debug!(scopes = ?req.scopes, include_content = req.include_content, "POST /skills/resolve");
1284
1285 let include_content = req.include_content;
1286 let request_scopes = ScopeConfig {
1287 scopes: req.scopes,
1288 sub: String::new(),
1289 expires_at: 0,
1290 rate_config: None,
1291 };
1292 let claims = claims.map(|Extension(claims)| claims);
1293 let caller_scopes = scopes_for_request(claims.as_ref(), &state);
1294 let visible_names = visible_skill_names(&state, &caller_scopes);
1295
1296 let resolved: Vec<&skill::SkillMeta> =
1297 skill::resolve_skills(&state.skill_registry, &state.registry, &request_scopes)
1298 .into_iter()
1299 .filter(|skill| visible_names.contains(&skill.name))
1300 .collect();
1301
1302 let json: Vec<Value> = resolved
1303 .iter()
1304 .map(|s| {
1305 let mut entry = serde_json::json!({
1306 "name": s.name,
1307 "version": s.version,
1308 "description": s.description,
1309 "tools": s.tools,
1310 "providers": s.providers,
1311 "categories": s.categories,
1312 });
1313 if include_content {
1314 if let Ok(content) = state.skill_registry.read_content(&s.name) {
1315 entry["content"] = Value::String(content);
1316 }
1317 }
1318 entry
1319 })
1320 .collect();
1321
1322 (StatusCode::OK, Json(Value::Array(json)))
1323}
1324
1325fn skillati_client(keyring: &Keyring) -> Result<SkillAtiClient, SkillAtiError> {
1326 match SkillAtiClient::from_env(keyring)? {
1327 Some(client) => Ok(client),
1328 None => Err(SkillAtiError::NotConfigured),
1329 }
1330}
1331
1332async fn handle_skillati_catalog(
1333 State(state): State<Arc<ProxyState>>,
1334 claims: Option<Extension<TokenClaims>>,
1335 Query(query): Query<SkillAtiCatalogQuery>,
1336) -> impl IntoResponse {
1337 tracing::debug!(search = ?query.search, "GET /skillati/catalog");
1338
1339 let client = match skillati_client(&state.keyring) {
1340 Ok(client) => client,
1341 Err(err) => return skillati_error_response(err),
1342 };
1343
1344 let claims = claims.map(|Extension(c)| c);
1345 let scopes = scopes_for_request(claims.as_ref(), &state);
1346 let visible_names = visible_skill_names(&state, &scopes);
1347
1348 match client.catalog().await {
1349 Ok(catalog) => {
1350 let mut skills: Vec<_> = catalog
1351 .into_iter()
1352 .filter(|s| visible_names.contains(&s.name))
1353 .collect();
1354 if let Some(search) = query.search.as_deref() {
1355 skills = SkillAtiClient::filter_catalog(&skills, search, 25);
1356 }
1357 (
1358 StatusCode::OK,
1359 Json(serde_json::json!({
1360 "skills": skills,
1361 })),
1362 )
1363 }
1364 Err(err) => skillati_error_response(err),
1365 }
1366}
1367
1368async fn handle_skillati_read(
1369 State(state): State<Arc<ProxyState>>,
1370 claims: Option<Extension<TokenClaims>>,
1371 axum::extract::Path(name): axum::extract::Path<String>,
1372) -> impl IntoResponse {
1373 tracing::debug!(%name, "GET /skillati/:name");
1374
1375 let client = match skillati_client(&state.keyring) {
1376 Ok(client) => client,
1377 Err(err) => return skillati_error_response(err),
1378 };
1379
1380 let claims = claims.map(|Extension(c)| c);
1381 let scopes = scopes_for_request(claims.as_ref(), &state);
1382 let visible_names = visible_skill_names(&state, &scopes);
1383 if !visible_names.contains(&name) {
1384 return skillati_error_response(SkillAtiError::SkillNotFound(name));
1385 }
1386
1387 match client.read_skill(&name).await {
1388 Ok(activation) => (StatusCode::OK, Json(serde_json::json!(activation))),
1389 Err(err) => skillati_error_response(err),
1390 }
1391}
1392
1393async fn handle_skillati_resources(
1394 State(state): State<Arc<ProxyState>>,
1395 claims: Option<Extension<TokenClaims>>,
1396 axum::extract::Path(name): axum::extract::Path<String>,
1397 Query(query): Query<SkillAtiResourcesQuery>,
1398) -> impl IntoResponse {
1399 tracing::debug!(%name, prefix = ?query.prefix, "GET /skillati/:name/resources");
1400
1401 let client = match skillati_client(&state.keyring) {
1402 Ok(client) => client,
1403 Err(err) => return skillati_error_response(err),
1404 };
1405
1406 let claims = claims.map(|Extension(c)| c);
1407 let scopes = scopes_for_request(claims.as_ref(), &state);
1408 let visible_names = visible_skill_names(&state, &scopes);
1409 if !visible_names.contains(&name) {
1410 return skillati_error_response(SkillAtiError::SkillNotFound(name));
1411 }
1412
1413 match client.list_resources(&name, query.prefix.as_deref()).await {
1414 Ok(resources) => (
1415 StatusCode::OK,
1416 Json(serde_json::json!({
1417 "name": name,
1418 "prefix": query.prefix,
1419 "resources": resources,
1420 })),
1421 ),
1422 Err(err) => skillati_error_response(err),
1423 }
1424}
1425
1426async fn handle_skillati_file(
1427 State(state): State<Arc<ProxyState>>,
1428 claims: Option<Extension<TokenClaims>>,
1429 axum::extract::Path(name): axum::extract::Path<String>,
1430 Query(query): Query<SkillAtiFileQuery>,
1431) -> impl IntoResponse {
1432 tracing::debug!(%name, path = %query.path, "GET /skillati/:name/file");
1433
1434 let client = match skillati_client(&state.keyring) {
1435 Ok(client) => client,
1436 Err(err) => return skillati_error_response(err),
1437 };
1438
1439 let claims = claims.map(|Extension(c)| c);
1440 let scopes = scopes_for_request(claims.as_ref(), &state);
1441 let visible_names = visible_skill_names(&state, &scopes);
1442 if !visible_names.contains(&name) {
1443 return skillati_error_response(SkillAtiError::SkillNotFound(name));
1444 }
1445
1446 match client.read_path(&name, &query.path).await {
1447 Ok(file) => (StatusCode::OK, Json(serde_json::json!(file))),
1448 Err(err) => skillati_error_response(err),
1449 }
1450}
1451
1452async fn handle_skillati_refs(
1453 State(state): State<Arc<ProxyState>>,
1454 claims: Option<Extension<TokenClaims>>,
1455 axum::extract::Path(name): axum::extract::Path<String>,
1456) -> impl IntoResponse {
1457 tracing::debug!(%name, "GET /skillati/:name/refs");
1458
1459 let client = match skillati_client(&state.keyring) {
1460 Ok(client) => client,
1461 Err(err) => return skillati_error_response(err),
1462 };
1463
1464 let claims = claims.map(|Extension(c)| c);
1465 let scopes = scopes_for_request(claims.as_ref(), &state);
1466 let visible_names = visible_skill_names(&state, &scopes);
1467 if !visible_names.contains(&name) {
1468 return skillati_error_response(SkillAtiError::SkillNotFound(name));
1469 }
1470
1471 match client.list_references(&name).await {
1472 Ok(references) => (
1473 StatusCode::OK,
1474 Json(serde_json::json!({
1475 "name": name,
1476 "references": references,
1477 })),
1478 ),
1479 Err(err) => skillati_error_response(err),
1480 }
1481}
1482
1483async fn handle_skillati_ref(
1484 State(state): State<Arc<ProxyState>>,
1485 claims: Option<Extension<TokenClaims>>,
1486 axum::extract::Path((name, reference)): axum::extract::Path<(String, String)>,
1487) -> impl IntoResponse {
1488 tracing::debug!(%name, %reference, "GET /skillati/:name/ref/:reference");
1489
1490 let client = match skillati_client(&state.keyring) {
1491 Ok(client) => client,
1492 Err(err) => return skillati_error_response(err),
1493 };
1494
1495 let claims = claims.map(|Extension(c)| c);
1496 let scopes = scopes_for_request(claims.as_ref(), &state);
1497 let visible_names = visible_skill_names(&state, &scopes);
1498 if !visible_names.contains(&name) {
1499 return skillati_error_response(SkillAtiError::SkillNotFound(name));
1500 }
1501
1502 match client.read_reference(&name, &reference).await {
1503 Ok(content) => (
1504 StatusCode::OK,
1505 Json(serde_json::json!({
1506 "name": name,
1507 "reference": reference,
1508 "content": content,
1509 })),
1510 ),
1511 Err(err) => skillati_error_response(err),
1512 }
1513}
1514
1515fn skillati_error_response(err: SkillAtiError) -> (StatusCode, Json<Value>) {
1516 let status = match &err {
1517 SkillAtiError::NotConfigured
1518 | SkillAtiError::UnsupportedRegistry(_)
1519 | SkillAtiError::MissingCredentials(_)
1520 | SkillAtiError::ProxyUrlRequired => StatusCode::SERVICE_UNAVAILABLE,
1521 SkillAtiError::SkillNotFound(_) | SkillAtiError::PathNotFound { .. } => {
1522 StatusCode::NOT_FOUND
1523 }
1524 SkillAtiError::InvalidPath(_) => StatusCode::BAD_REQUEST,
1525 SkillAtiError::Gcs(_)
1526 | SkillAtiError::ProxyRequest(_)
1527 | SkillAtiError::ProxyResponse(_) => StatusCode::BAD_GATEWAY,
1528 };
1529
1530 (
1531 status,
1532 Json(serde_json::json!({
1533 "error": err.to_string(),
1534 })),
1535 )
1536}
1537
1538async fn auth_middleware(
1546 State(state): State<Arc<ProxyState>>,
1547 mut req: HttpRequest<Body>,
1548 next: Next,
1549) -> Result<Response, StatusCode> {
1550 let path = req.uri().path();
1551
1552 if path == "/health" || path == "/.well-known/jwks.json" {
1554 return Ok(next.run(req).await);
1555 }
1556
1557 let jwt_config = match &state.jwt_config {
1559 Some(c) => c,
1560 None => return Ok(next.run(req).await),
1561 };
1562
1563 let auth_header = req
1565 .headers()
1566 .get("authorization")
1567 .and_then(|v| v.to_str().ok());
1568
1569 let token = match auth_header {
1570 Some(header) if header.starts_with("Bearer ") => &header[7..],
1571 _ => return Err(StatusCode::UNAUTHORIZED),
1572 };
1573
1574 match jwt::validate(token, jwt_config) {
1576 Ok(claims) => {
1577 tracing::debug!(sub = %claims.sub, scopes = %claims.scope, "JWT validated");
1578 req.extensions_mut().insert(claims);
1579 Ok(next.run(req).await)
1580 }
1581 Err(e) => {
1582 tracing::debug!(error = %e, "JWT validation failed");
1583 Err(StatusCode::UNAUTHORIZED)
1584 }
1585 }
1586}
1587
1588pub fn build_router(state: Arc<ProxyState>) -> Router {
1592 Router::new()
1593 .route("/call", post(handle_call))
1594 .route("/help", post(handle_help))
1595 .route("/mcp", post(handle_mcp))
1596 .route("/tools", get(handle_tools_list))
1597 .route("/tools/{name}", get(handle_tool_info))
1598 .route("/skills", get(handle_skills_list))
1599 .route("/skills/resolve", post(handle_skills_resolve))
1600 .route("/skills/bundle", post(handle_skills_bundle_batch))
1601 .route("/skills/{name}", get(handle_skill_detail))
1602 .route("/skills/{name}/bundle", get(handle_skill_bundle))
1603 .route("/skillati/catalog", get(handle_skillati_catalog))
1604 .route("/skillati/{name}", get(handle_skillati_read))
1605 .route("/skillati/{name}/resources", get(handle_skillati_resources))
1606 .route("/skillati/{name}/file", get(handle_skillati_file))
1607 .route("/skillati/{name}/refs", get(handle_skillati_refs))
1608 .route("/skillati/{name}/ref/{reference}", get(handle_skillati_ref))
1609 .route("/health", get(handle_health))
1610 .route("/.well-known/jwks.json", get(handle_jwks))
1611 .layer(middleware::from_fn_with_state(
1612 state.clone(),
1613 auth_middleware,
1614 ))
1615 .with_state(state)
1616}
1617
1618pub async fn run(
1622 port: u16,
1623 bind_addr: Option<String>,
1624 ati_dir: PathBuf,
1625 _verbose: bool,
1626 env_keys: bool,
1627) -> Result<(), Box<dyn std::error::Error>> {
1628 let manifests_dir = ati_dir.join("manifests");
1630 let mut registry = ManifestRegistry::load(&manifests_dir)?;
1631 let provider_count = registry.list_providers().len();
1632
1633 let keyring_source;
1635 let keyring = if env_keys {
1636 let kr = Keyring::from_env();
1638 let key_names = kr.key_names();
1639 tracing::info!(
1640 count = key_names.len(),
1641 "loaded API keys from ATI_KEY_* env vars"
1642 );
1643 for name in &key_names {
1644 tracing::debug!(key = %name, "env key loaded");
1645 }
1646 keyring_source = "env-vars (ATI_KEY_*)";
1647 kr
1648 } else {
1649 let keyring_path = ati_dir.join("keyring.enc");
1651 if keyring_path.exists() {
1652 if let Ok(kr) = Keyring::load(&keyring_path) {
1653 keyring_source = "keyring.enc (sealed key)";
1654 kr
1655 } else if let Ok(kr) = Keyring::load_local(&keyring_path, &ati_dir) {
1656 keyring_source = "keyring.enc (persistent key)";
1657 kr
1658 } else {
1659 tracing::warn!("keyring.enc exists but could not be decrypted");
1660 keyring_source = "empty (decryption failed)";
1661 Keyring::empty()
1662 }
1663 } else {
1664 let creds_path = ati_dir.join("credentials");
1665 if creds_path.exists() {
1666 match Keyring::load_credentials(&creds_path) {
1667 Ok(kr) => {
1668 keyring_source = "credentials (plaintext)";
1669 kr
1670 }
1671 Err(e) => {
1672 tracing::warn!(error = %e, "failed to load credentials");
1673 keyring_source = "empty (credentials error)";
1674 Keyring::empty()
1675 }
1676 }
1677 } else {
1678 tracing::warn!("no keyring.enc or credentials found — running without API keys");
1679 tracing::warn!("tools requiring authentication will fail");
1680 keyring_source = "empty (no auth)";
1681 Keyring::empty()
1682 }
1683 }
1684 };
1685
1686 mcp_client::discover_all_mcp_tools(&mut registry, &keyring).await;
1689
1690 let tool_count = registry.list_public_tools().len();
1691
1692 let mcp_providers: Vec<(String, String)> = registry
1694 .list_mcp_providers()
1695 .iter()
1696 .map(|p| (p.name.clone(), p.mcp_transport_type().to_string()))
1697 .collect();
1698 let mcp_count = mcp_providers.len();
1699 let openapi_providers: Vec<String> = registry
1700 .list_openapi_providers()
1701 .iter()
1702 .map(|p| p.name.clone())
1703 .collect();
1704 let openapi_count = openapi_providers.len();
1705
1706 let skills_dir = ati_dir.join("skills");
1708 let skill_registry = SkillRegistry::load(&skills_dir).unwrap_or_else(|e| {
1709 tracing::warn!(error = %e, "failed to load skills");
1710 SkillRegistry::load(std::path::Path::new("/nonexistent-fallback")).unwrap()
1711 });
1712
1713 if let Ok(registry_url) = std::env::var("ATI_SKILL_REGISTRY") {
1714 if registry_url.strip_prefix("gcs://").is_some() {
1715 tracing::info!(
1716 registry = %registry_url,
1717 "SkillATI remote registry configured for lazy reads"
1718 );
1719 } else {
1720 tracing::warn!(url = %registry_url, "SkillATI only supports gcs:// registries");
1721 }
1722 }
1723
1724 let skill_count = skill_registry.skill_count();
1725
1726 let jwt_config = match jwt::config_from_env() {
1728 Ok(config) => config,
1729 Err(e) => {
1730 tracing::warn!(error = %e, "JWT config error");
1731 None
1732 }
1733 };
1734
1735 let auth_status = if jwt_config.is_some() {
1736 "JWT enabled"
1737 } else {
1738 "DISABLED (no JWT keys configured)"
1739 };
1740
1741 let jwks_json = jwt_config.as_ref().and_then(|config| {
1743 config
1744 .public_key_pem
1745 .as_ref()
1746 .and_then(|pem| jwt::public_key_to_jwks(pem, config.algorithm, "ati-proxy-1").ok())
1747 });
1748
1749 let state = Arc::new(ProxyState {
1750 registry,
1751 skill_registry,
1752 keyring,
1753 jwt_config,
1754 jwks_json,
1755 auth_cache: AuthCache::new(),
1756 });
1757
1758 let app = build_router(state);
1759
1760 let addr: SocketAddr = if let Some(ref bind) = bind_addr {
1761 format!("{bind}:{port}").parse()?
1762 } else {
1763 SocketAddr::from(([127, 0, 0, 1], port))
1764 };
1765
1766 tracing::info!(
1767 version = env!("CARGO_PKG_VERSION"),
1768 %addr,
1769 auth = auth_status,
1770 ati_dir = %ati_dir.display(),
1771 tools = tool_count,
1772 providers = provider_count,
1773 mcp = mcp_count,
1774 openapi = openapi_count,
1775 skills = skill_count,
1776 keyring = keyring_source,
1777 "ATI proxy server starting"
1778 );
1779 for (name, transport) in &mcp_providers {
1780 tracing::info!(provider = %name, transport = %transport, "MCP provider");
1781 }
1782 for name in &openapi_providers {
1783 tracing::info!(provider = %name, "OpenAPI provider");
1784 }
1785
1786 let listener = tokio::net::TcpListener::bind(addr).await?;
1787 axum::serve(listener, app).await?;
1788
1789 Ok(())
1790}
1791
1792fn write_proxy_audit(
1794 call_req: &CallRequest,
1795 agent_sub: &str,
1796 duration: std::time::Duration,
1797 error: Option<&str>,
1798) {
1799 let entry = crate::core::audit::AuditEntry {
1800 ts: chrono::Utc::now().to_rfc3339(),
1801 tool: call_req.tool_name.clone(),
1802 args: crate::core::audit::sanitize_args(&call_req.args),
1803 status: if error.is_some() {
1804 crate::core::audit::AuditStatus::Error
1805 } else {
1806 crate::core::audit::AuditStatus::Ok
1807 },
1808 duration_ms: duration.as_millis() as u64,
1809 agent_sub: agent_sub.to_string(),
1810 error: error.map(|s| s.to_string()),
1811 exit_code: None,
1812 };
1813 let _ = crate::core::audit::append(&entry);
1814}
1815
/// System-prompt template for the unscoped help assistant.
///
/// Contains two literal placeholders, `{tools}` and `{skills_section}`.
/// Presumably the help handler substitutes them with the rendered tool list
/// and skills section before sending the prompt — confirm against the caller.
const HELP_SYSTEM_PROMPT: &str = r#"You are a helpful assistant for an AI agent that uses external tools via the `ati` CLI.

## Available Tools
{tools}

{skills_section}

Answer the agent's question naturally, like a knowledgeable colleague would. Keep it short but useful:

- Explain which tools to use and why, with `ati run` commands showing realistic parameter values
- If multiple steps are needed, walk through them briefly in order
- Mention important gotchas or parameter choices that matter
- If skills are relevant, suggest `ati skill show <name>` for the full methodology

Keep your answer concise — a few short paragraphs with embedded code blocks. Only recommend tools from the list above."#;
1833
1834async fn build_remote_skillati_section(keyring: &Keyring, query: &str, limit: usize) -> String {
1835 let client = match SkillAtiClient::from_env(keyring) {
1836 Ok(Some(client)) => client,
1837 Ok(None) => return String::new(),
1838 Err(err) => {
1839 tracing::warn!(error = %err, "failed to initialize SkillATI catalog for proxy help");
1840 return String::new();
1841 }
1842 };
1843
1844 let catalog = match client.catalog().await {
1845 Ok(catalog) => catalog,
1846 Err(err) => {
1847 tracing::warn!(error = %err, "failed to load SkillATI catalog for proxy help");
1848 return String::new();
1849 }
1850 };
1851
1852 let matched = SkillAtiClient::filter_catalog(&catalog, query, limit);
1853 if matched.is_empty() {
1854 return String::new();
1855 }
1856
1857 render_remote_skillati_section(&matched, catalog.len())
1858}
1859
1860fn render_remote_skillati_section(skills: &[RemoteSkillMeta], total_catalog: usize) -> String {
1861 let mut section = String::from("## Remote Skills Available Via SkillATI\n\n");
1862 section.push_str(
1863 "These skills are available remotely from the SkillATI registry. They are not installed locally. Activate one on demand with `ati skillati read <name>`, inspect bundled paths with `ati skillati resources <name>`, and fetch specific files with `ati skillati cat <name> <path>`.\n\n",
1864 );
1865
1866 for skill in skills {
1867 section.push_str(&format!("- **{}**: {}\n", skill.name, skill.description));
1868 }
1869
1870 if total_catalog > skills.len() {
1871 section.push_str(&format!(
1872 "\nOnly the most relevant {} remote skills are shown here.\n",
1873 skills.len()
1874 ));
1875 }
1876
1877 section
1878}
1879
/// Joins help-prompt skill sections into one block of text.
///
/// Each section is trimmed of surrounding whitespace; sections that are empty
/// after trimming are dropped, and the rest are separated by a blank line.
/// Returns an empty string when nothing remains.
fn merge_help_skill_sections(sections: &[String]) -> String {
    let kept: Vec<&str> = sections
        .iter()
        .map(|section| section.trim())
        .filter(|section| !section.is_empty())
        .collect();
    kept.join("\n\n")
}
1894
1895fn build_tool_context(
1896 tools: &[(
1897 &crate::core::manifest::Provider,
1898 &crate::core::manifest::Tool,
1899 )],
1900) -> String {
1901 let mut summaries = Vec::new();
1902 for (provider, tool) in tools {
1903 let mut summary = if let Some(cat) = &provider.category {
1904 format!(
1905 "- **{}** (provider: {}, category: {}): {}",
1906 tool.name, provider.name, cat, tool.description
1907 )
1908 } else {
1909 format!(
1910 "- **{}** (provider: {}): {}",
1911 tool.name, provider.name, tool.description
1912 )
1913 };
1914 if !tool.tags.is_empty() {
1915 summary.push_str(&format!("\n Tags: {}", tool.tags.join(", ")));
1916 }
1917 if provider.is_cli() && tool.input_schema.is_none() {
1919 let cmd = provider.cli_command.as_deref().unwrap_or("?");
1920 summary.push_str(&format!(
1921 "\n Usage: `ati run {} -- <args>` (passthrough to `{}`)",
1922 tool.name, cmd
1923 ));
1924 } else if let Some(schema) = &tool.input_schema {
1925 if let Some(props) = schema.get("properties") {
1926 if let Some(obj) = props.as_object() {
1927 let params: Vec<String> = obj
1928 .iter()
1929 .filter(|(_, v)| {
1930 v.get("x-ati-param-location").is_none()
1931 || v.get("description").is_some()
1932 })
1933 .map(|(k, v)| {
1934 let type_str =
1935 v.get("type").and_then(|t| t.as_str()).unwrap_or("string");
1936 let desc = v.get("description").and_then(|d| d.as_str()).unwrap_or("");
1937 format!(" --{k} ({type_str}): {desc}")
1938 })
1939 .collect();
1940 if !params.is_empty() {
1941 summary.push_str("\n Parameters:\n");
1942 summary.push_str(¶ms.join("\n"));
1943 }
1944 }
1945 }
1946 }
1947 summaries.push(summary);
1948 }
1949 summaries.join("\n\n")
1950}
1951
1952fn build_scoped_prompt(
1956 scope_name: &str,
1957 visible_tools: &[(&Provider, &Tool)],
1958 skills_section: &str,
1959) -> Option<String> {
1960 if let Some((provider, tool)) = visible_tools
1962 .iter()
1963 .find(|(_, tool)| tool.name == scope_name)
1964 {
1965 let mut details = format!(
1966 "**Name**: `{}`\n**Provider**: {} (handler: {})\n**Description**: {}\n",
1967 tool.name, provider.name, provider.handler, tool.description
1968 );
1969 if let Some(cat) = &provider.category {
1970 details.push_str(&format!("**Category**: {}\n", cat));
1971 }
1972 if provider.is_cli() {
1973 let cmd = provider.cli_command.as_deref().unwrap_or("?");
1974 details.push_str(&format!(
1975 "\n**Usage**: `ati run {} -- <args>` (passthrough to `{}`)\n",
1976 tool.name, cmd
1977 ));
1978 } else if let Some(schema) = &tool.input_schema {
1979 if let Some(props) = schema.get("properties").and_then(|p| p.as_object()) {
1980 let required: Vec<String> = schema
1981 .get("required")
1982 .and_then(|r| r.as_array())
1983 .map(|arr| {
1984 arr.iter()
1985 .filter_map(|v| v.as_str().map(|s| s.to_string()))
1986 .collect()
1987 })
1988 .unwrap_or_default();
1989 details.push_str("\n**Parameters**:\n");
1990 for (key, val) in props {
1991 let type_str = val.get("type").and_then(|t| t.as_str()).unwrap_or("string");
1992 let desc = val
1993 .get("description")
1994 .and_then(|d| d.as_str())
1995 .unwrap_or("");
1996 let req = if required.contains(key) {
1997 " **(required)**"
1998 } else {
1999 ""
2000 };
2001 details.push_str(&format!("- `--{key}` ({type_str}{req}): {desc}\n"));
2002 }
2003 }
2004 }
2005
2006 let prompt = format!(
2007 "You are an expert assistant for the `{}` tool, accessed via the `ati` CLI.\n\n\
2008 ## Tool Details\n{}\n\n{}\n\n\
2009 Answer the agent's question about this specific tool. Provide exact commands, explain flags and options, and give practical examples. Be concise and actionable.",
2010 tool.name, details, skills_section
2011 );
2012 return Some(prompt);
2013 }
2014
2015 let tools: Vec<(&Provider, &Tool)> = visible_tools
2017 .iter()
2018 .copied()
2019 .filter(|(provider, _)| provider.name == scope_name)
2020 .collect();
2021 if !tools.is_empty() {
2022 let tools_context = build_tool_context(&tools);
2023 let prompt = format!(
2024 "You are an expert assistant for the `{}` provider's tools, accessed via the `ati` CLI.\n\n\
2025 ## Tools in provider `{}`\n{}\n\n{}\n\n\
2026 Answer the agent's question about these tools. Provide exact `ati run` commands, explain parameters, and give practical examples. Be concise and actionable.",
2027 scope_name, scope_name, tools_context, skills_section
2028 );
2029 return Some(prompt);
2030 }
2031
2032 None
2033}